Merge branch 'main' of github.com:langgenius/dify into feat/no-root-image

This commit is contained in:
Byron Wang 2025-11-13 13:56:42 +08:00
commit 4e201ef059
No known key found for this signature in database
GPG Key ID: 335E934E215AD579
2300 changed files with 132608 additions and 29325 deletions

View File

@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@ -11,7 +11,7 @@
"nodeGypDependencies": true, "nodeGypDependencies": true,
"version": "lts" "version": "lts"
}, },
"ghcr.io/devcontainers-contrib/features/npm-package:1": { "ghcr.io/devcontainers-extra/features/npm-package:1": {
"package": "typescript", "package": "typescript",
"version": "latest" "version": "latest"
}, },

View File

@ -6,7 +6,7 @@ cd web && pnpm install
pipx install uv pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

View File

@ -39,25 +39,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: uv sync --project api --dev run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check - name: Run pyrefly check
run: | run: |
cd api cd api
uv add --dev pyrefly uv add --dev pyrefly
uv run pyrefly check || true uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers - name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -2,6 +2,8 @@ name: autofix.ci
on: on:
pull_request: pull_request:
branches: ["main"] branches: ["main"]
push:
branches: ["main"]
permissions: permissions:
contents: read contents: read
@ -30,6 +32,8 @@ jobs:
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types) # Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF' cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union id: convert-optional-to-union

View File

@ -4,8 +4,7 @@ on:
push: push:
branches: branches:
- "main" - "main"
- "deploy/dev" - "deploy/**"
- "deploy/enterprise"
- "build/**" - "build/**"
- "release/e-*" - "release/e-*"
- "hotfix/**" - "hotfix/**"

View File

@ -18,7 +18,7 @@ jobs:
- name: Deploy to server - name: Deploy to server
uses: appleboy/ssh-action@v0.1.8 uses: appleboy/ssh-action@v0.1.8
with: with:
host: ${{ secrets.RAG_SSH_HOST }} host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }} username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }} key: ${{ secrets.SSH_PRIVATE_KEY }}
script: | script: |

View File

@ -1,4 +1,4 @@
name: Deploy RAG Dev name: Deploy Trigger Dev
permissions: permissions:
contents: read contents: read
@ -7,7 +7,7 @@ on:
workflow_run: workflow_run:
workflows: ["Build and Push API & Web"] workflows: ["Build and Push API & Web"]
branches: branches:
- "deploy/rag-dev" - "deploy/trigger-dev"
types: types:
- completed - completed
@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev' github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps: steps:
- name: Deploy to server - name: Deploy to server
uses: appleboy/ssh-action@v0.1.8 uses: appleboy/ssh-action@v0.1.8
with: with:
host: ${{ secrets.RAG_SSH_HOST }} host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }} username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }} key: ${{ secrets.SSH_PRIVATE_KEY }}
script: | script: |

View File

@ -1,6 +1,7 @@
#!/bin/bash #!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -103,6 +103,11 @@ jobs:
run: | run: |
pnpm run lint pnpm run lint
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
docker-compose-template: docker-compose-template:
name: Docker Compose Template name: Docker Compose Template
runs-on: ubuntu-latest runs-on: ubuntu-latest

9
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions # C extensions
*.so *.so
# *db files
*.db
# Distribution / packaging # Distribution / packaging
.Python .Python
build/ build/
@ -97,6 +100,7 @@ __pypackages__/
# Celery stuff # Celery stuff
celerybeat-schedule celerybeat-schedule
celerybeat-schedule.db
celerybeat.pid celerybeat.pid
# SageMath parsed files # SageMath parsed files
@ -234,4 +238,7 @@ scripts/stress-test/reports/
# mcp # mcp
.playwright-mcp/ .playwright-mcp/
.serena/ .serena/
# settings
*.local.json

View File

@ -8,8 +8,7 @@
"module": "flask", "module": "flask",
"env": { "env": {
"FLASK_APP": "app.py", "FLASK_APP": "app.py",
"FLASK_ENV": "development", "FLASK_ENV": "development"
"GEVENT_SUPPORT": "True"
}, },
"args": [ "args": [
"run", "run",
@ -28,9 +27,7 @@
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"module": "celery", "module": "celery",
"env": { "env": {},
"GEVENT_SUPPORT": "True"
},
"args": [ "args": [
"-A", "-A",
"app.celery", "app.celery",
@ -40,7 +37,7 @@
"-c", "-c",
"1", "1",
"-Q", "-Q",
"dataset,generation,mail,ops_trace", "dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
"--loglevel", "--loglevel",
"INFO" "INFO"
], ],

View File

@ -14,7 +14,7 @@ The codebase is split into:
- Run backend CLI commands through `uv run --project api <command>`. - Run backend CLI commands through `uv run --project api <command>`.
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review. - Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. - Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.

View File

@ -26,7 +26,6 @@ prepare-web:
@echo "🌐 Setting up web environment..." @echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists" @cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install @cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)" @echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment # Step 3: Prepare API environment

View File

@ -40,18 +40,18 @@
<p align="center"> <p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a> <a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a> <a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a> <a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a> <a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a> <a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a> <a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a> <a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a> <a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a> <a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a> <a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a> <a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a> <a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a> <a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p> </p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
> - CPU >= 2 Core > - CPU >= 2 Core
> - RAM >= 4 GiB > - RAM >= 4 GiB
</br> <br/>
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
## Using Dify ## Using Dify
- **Cloud </br>** - **Cloud <br/>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>** - **Self-hosting Dify Community Edition<br/>**
Quickly get Dify running in your environment with this [starter guide](#quick-start). Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations</br>** - **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br> We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding. > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
## Advanced Setup ## Advanced Setup
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)

View File

@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
# Example: INTERNAL_FILES_URL=http://api:5001 # Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001 INTERNAL_FILES_URL=http://127.0.0.1:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected # The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300 FILES_ACCESS_TIMEOUT=300
@ -156,6 +159,9 @@ SUPABASE_URL=your-server-url
# CORS configuration # CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
COOKIE_DOMAIN=
# Vector database configuration # Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
@ -343,6 +349,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration # openGauss configuration
OPENGAUSS_HOST=127.0.0.1 OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600 OPENGAUSS_PORT=6600
@ -359,6 +374,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration # Model configuration
MULTIMODAL_SEND_FORMAT=base64 MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
@ -425,10 +446,13 @@ CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100 CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20 CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0 CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_EXECUTION_CONNECT_TIMEOUT=10
CODE_EXECUTION_READ_TIMEOUT=60
CODE_EXECUTION_WRITE_TIMEOUT=10
CODE_MAX_NUMBER=9223372036854775807 CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808 CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000 CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_MAX_NUMBER_ARRAY_LENGTH=1000
@ -445,6 +469,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients # Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false RESPECT_XFORWARD_HEADERS_ENABLED=false
@ -500,7 +527,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration # Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size # Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true WORKFLOW_LOG_CLEANUP_ENABLED=false
# Number of days to retain workflow run logs (default: 30 days) # Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30 WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100) # Batch size for workflow log cleanup operations (default: 100)
@ -522,6 +549,12 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows(default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration # Position configuration
POSITION_TOOL_PINS= POSITION_TOOL_PINS=
@ -593,3 +626,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true) # Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import # Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

View File

@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
] ]
[lint.per-file-ignores] [lint.per-file-ignores]

View File

@ -54,7 +54,7 @@
"--loglevel", "--loglevel",
"DEBUG", "DEBUG",
"-Q", "-Q",
"dataset,generation,mail,ops_trace,app_deletion" "dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
] ]
} }
] ]

62
api/AGENTS.md Normal file
View File

@ -0,0 +1,62 @@
# Agent Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
______________________________________________________________________
## Platform Foundations
- **[Infrastructure Overview](agent_skills/infra.md)**\
When to read this:
- You need to understand where a feature belongs in the architecture.
- Youre wiring storage, Redis, vector stores, or OTEL.
- Youre about to add CLI commands or async jobs.\
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
- **[Coding Style](agent_skills/coding_style.md)**\
When to read this:
- Youre writing or reviewing backend code and need the authoritative checklist.
- Youre unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- You want the exact lint/type/test commands used in PRs.\
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
______________________________________________________________________
## Plugin & Extension Development
- **[Plugin Systems](agent_skills/plugin.md)**\
When to read this:
- Youre building or debugging a marketplace plugin.
- You need to know how manifests, providers, daemons, and migrations fit together.\
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
When to read this:
- You must integrate OAuth for a plugin or datasource.
- Youre handling credential encryption or refresh flows.\
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
______________________________________________________________________
## Workflow Entry & Execution
- **[Trigger Concepts](agent_skills/trigger.md)**\
When to read this:
- Youre debugging why a workflow didnt start.
- Youre adding a new trigger type or hook.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
______________________________________________________________________
## Additional Notes for Agents
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.

View File

@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources # RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev && apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies # Install Python dependencies
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
@ -49,7 +53,9 @@ RUN \
# Install dependencies # Install dependencies
&& apt-get install -y --no-install-recommends \ && apt-get install -y --no-install-recommends \
# basic environment # basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security # For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \ expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2 # install fonts to support the use of tools like pypdfium2

View File

@ -80,7 +80,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
``` ```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -0,0 +1,115 @@
## Linter
- Always follow `.ruff.toml`.
- Run `uv run ruff check --fix --unsafe-fixes`.
- Keep each line under 100 characters (including spaces).
## Code Style
- `snake_case` for variables and functions.
- `PascalCase` for classes.
- `UPPER_CASE` for constants.
## Rules
- Use Pydantic v2 standard.
- Use `uv` for package management.
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
- Prefer simple functions over classes for lightweight helpers.
- Keep files below 800 lines; split when necessary.
- Keep code readable—no clever hacks.
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
## Guiding Principles
- Mirror the projects layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
## SQLAlchemy Patterns
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
- Open sessions with context managers:
```python
from sqlalchemy.orm import Session
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
## Storage & External IO
- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
## Pydantic Usage
- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.
- Example:
```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
class TriggerConfig(BaseModel):
endpoint: HttpUrl
secret: str
model_config = ConfigDict(extra="forbid")
@field_validator("secret")
def ensure_secret_prefix(cls, value: str) -> str:
if not value.startswith("dify_"):
raise ValueError("secret must start with dify_")
return value
```
## Generics & Protocols
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
## Error Handling & Logging
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
- Declare `logger = logging.getLogger(__name__)` at module top.
- Include tenant/app/workflow identifiers in log context.
- Log retryable events at `warning`, terminal failures at `error`.
## Tooling & Checks
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
- Type checks: `uv run --directory api --dev basedpyright`.
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Run all of the above before submitting your work.
## Controllers & Services
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
- Document non-obvious behaviour with concise comments.
## Miscellaneous
- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.

96
api/agent_skills/infra.md Normal file
View File

@ -0,0 +1,96 @@
## Configuration
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
## Dependencies
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
## Storage & Files
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
## Redis & Shared State
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
## Models
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
## Vector Stores
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
## Observability & OTEL
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
## Ops Integrations
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
## Controllers, Services, Core
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
## Plugins, Tools, Providers
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
## Async Workloads
see `agent_skills/trigger.md` for more detailed documentation.
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
## Database & Migrations
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
## CLI Commands
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
## When You Add Features
- Check for an existing helper or service before writing a new util.
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.

View File

@ -0,0 +1 @@
// TBD

View File

@ -0,0 +1 @@
// TBD

View File

@ -0,0 +1,53 @@
## Overview
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
## Trigger nodes
- `UserInput`
- `Trigger Webhook`
- `Trigger Schedule`
- `Trigger Plugin`
### UserInput
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
### Trigger Webhook
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
### Trigger Schedule
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
### Trigger Plugin
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
## Worker Pool / Async Task
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
## Debug Strategy
Dify divided users into 2 groups: builders / end users.
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.

View File

@ -1,7 +1,7 @@
import sys import sys
def is_db_command(): def is_db_command() -> bool:
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True return True
return False return False
@ -13,23 +13,12 @@ if is_db_command():
app = create_migrations_app() app = create_migrations_app()
else: else:
# It seems that JetBrains Python debugger does not work well with gevent, # Gunicorn and Celery handle monkey patching automatically in production by
# so we need to disable gevent in debug mode. # specifying the `gevent` worker class. Manual monkey patching is not required here.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
# #
# # gevent # See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
# monkey.patch_all()
# #
# from grpc.experimental import gevent as grpc_gevent # type: ignore # For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
from app_factory import create_app from app_factory import create_app

View File

@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from core.helper import encrypter from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created from events.app_event import app_was_created
from extensions.ext_database import db from extensions.ext_database import db
@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
) )
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
if not datasets.items:
break
except SQLAlchemyError: except SQLAlchemyError:
raise raise
@ -1227,6 +1229,55 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
""" """
Find draft variables that reference non-existent apps. Find draft variables that reference non-existent apps.
@ -1420,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.") @click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials(): @click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
""" """
Transform datasource credentials Transform datasource credentials
""" """
@ -1431,9 +1485,14 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource" notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource" firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource" jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage] if environment == "online":
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage] notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage] firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
oauth_credential_type = CredentialType.OAUTH2 oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY api_key_credential_type = CredentialType.API_KEY
@ -1521,6 +1580,14 @@ def transform_datasource_credentials():
auth_count = 0 auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials: for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1 auth_count += 1
if not firecrawl_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key # get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials) credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
@ -1576,6 +1643,14 @@ def transform_datasource_credentials():
auth_count = 0 auth_count = 0
for jina_tenant_credential in jina_tenant_credentials: for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1 auth_count += 1
if not jina_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key # get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials) credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
@ -1583,7 +1658,7 @@ def transform_datasource_credentials():
"integration_secret": api_key, "integration_secret": api_key,
} }
datasource_provider = DatasourceProvider( datasource_provider = DatasourceProvider(
provider="jina", provider="jinareader",
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=jina_plugin_id, plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value, auth_type=api_key_credential_type.value,

View File

@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field( CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution", description="Maximum allowed length for strings in code execution",
default=80000, default=400_000,
) )
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@ -174,6 +174,33 @@ class CodeExecutionSandboxConfig(BaseSettings):
) )
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class AsyncWorkflowConfig(BaseSettings):
"""
Configuration for async workflow
"""
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
description="Granularity for async workflow scheduler, "
"sometime, few users could block the queue due to some time-consuming tasks, "
"to avoid this, workflow can be suspended if needed, to achieve"
"this, a time-based checker is required, every granularity seconds, "
"the checker will check the workflow queue and suspend the workflow",
default=120,
ge=1,
)
class PluginConfig(BaseSettings): class PluginConfig(BaseSettings):
""" """
Plugin configs Plugin configs
@ -189,6 +216,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key", default="plugin-api-key",
) )
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field( PLUGIN_REMOTE_INSTALL_HOST: str = Field(
@ -258,6 +290,8 @@ class EndpointConfig(BaseSettings):
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}" description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
) )
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
class FileAccessConfig(BaseSettings): class FileAccessConfig(BaseSettings):
""" """
@ -326,12 +360,42 @@ class FileUploadConfig(BaseSettings):
default=10, default=10,
) )
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
class HttpConfig(BaseSettings): class HttpConfig(BaseSettings):
""" """
HTTP-related configurations for the application HTTP-related configurations for the application
""" """
COOKIE_DOMAIN: str = Field(
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
default="",
)
API_COMPRESSION_ENABLED: bool = Field( API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses", description="Enable or disable gzip compression for HTTP responses",
default=False, default=False,
@ -362,11 +426,11 @@ class HttpConfig(BaseSettings):
) )
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@ -543,7 +607,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings): class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field( WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB # 1000 KiB
1024_000, 1024_000,
description="Maximum size for variable to trigger final truncation.", description="Maximum size for variable to trigger final truncation.",
) )
@ -582,6 +646,11 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024, default=200 * 1024,
) )
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration # GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance", description="Minimum number of workers per GraphEngine instance",
@ -766,7 +835,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field( MAIL_TEMPLATING_TIMEOUT: int = Field(
description=""" description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates. Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""", Only available in sandbox mode.""",
default=3, default=3,
) )
@ -905,6 +974,11 @@ class DataSetConfig(BaseSettings):
default=True, default=True,
) )
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """
@ -980,6 +1054,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task", description="Enable check upgradable plugin task",
default=True, default=True,
) )
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",
default=True,
)
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
description="Trigger provider refresh poller interval in minutes",
default=1,
)
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
description="Max trigger subscriptions to process per tick",
default=200,
)
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds",
default=180,
)
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds",
default=60 * 60,
)
class PositionConfig(BaseSettings): class PositionConfig(BaseSettings):
@ -1078,7 +1190,7 @@ class AccountConfig(BaseSettings):
class WorkflowLogConfig(BaseSettings): class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup") WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs") WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field( WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations" default=100, description="Batch size for workflow run log cleanup operations"
@ -1097,12 +1209,21 @@ class SwaggerUIConfig(BaseSettings):
) )
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig, BillingConfig,
CodeExecutionSandboxConfig, CodeExecutionSandboxConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig, PluginConfig,
MarketplaceConfig, MarketplaceConfig,
DataSetConfig, DataSetConfig,
@ -1121,6 +1242,7 @@ class FeatureConfig(
RagEtlConfig, RagEtlConfig,
RepositoryConfig, RepositoryConfig,
SecurityConfig, SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig, ToolConfig,
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,

View File

@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig from .vdb.chroma_config import ChromaConfig
@ -144,7 +145,7 @@ class DatabaseConfig(BaseSettings):
default="postgresql", default="postgresql",
) )
@computed_field # type: ignore[misc] @computed_field # type: ignore[prop-decorator]
@property @property
def SQLALCHEMY_DATABASE_URI(self) -> str: def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = ( db_extras = (
@ -197,7 +198,7 @@ class DatabaseConfig(BaseSettings):
default=os.cpu_count() or 1, default=os.cpu_count() or 1,
) )
@computed_field # type: ignore[misc] @computed_field # type: ignore[prop-decorator]
@property @property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options' # Parse DB_EXTRAS for 'options'
@ -330,6 +331,7 @@ class MiddlewareConfig(
ClickzettaConfig, ClickzettaConfig,
HuaweiCloudConfig, HuaweiCloudConfig,
MilvusConfig, MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig, MyScaleConfig,
OpenSearchConfig, OpenSearchConfig,
OracleConfig, OracleConfig,

View File

@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)

View File

@ -1,23 +1,24 @@
from enum import Enum from enum import StrEnum
from typing import Literal from typing import Literal
from pydantic import Field, PositiveInt from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings): class OpenSearchConfig(BaseSettings):
""" """
Configuration settings for OpenSearch Configuration settings for OpenSearch
""" """
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field( OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None, default=None,

View File

@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
default=True, default=True,
) )
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field( WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)", description="Number of objects to be processed in a single batch operation (default is 100)",
default=100, default=100,

View File

@ -1,4 +1,5 @@
from configs import dify_config from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]" UNKNOWN_VALUE = "[__UNKNOWN__]"
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3 DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
_doc_extensions: set[str]
_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured": if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] _doc_extensions = {
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) "txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL: if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt") _doc_extensions.add("ppt")
else: else:
_doc_extensions = [ _doc_extensions = {
"txt", "txt",
"markdown", "markdown",
"md", "md",
@ -37,5 +53,18 @@ else:
"csv", "csv",
"vtt", "vtt",
"properties", "properties",
] }
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -31,3 +31,9 @@ def supported_language(lang):
error = f"{lang} is not a valid language." error = f"{lang} is not a valid language."
raise ValueError(error) raise ValueError(error)
def get_valid_language(lang: str | None) -> str:
if lang and lang in languages:
return lang
return languages[0]

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
""" """
@ -41,3 +42,11 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock") ContextVar("datasource_plugin_providers_lock")
) )
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)

View File

@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415 code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException): class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files" error_code = "too_many_files"
description = "Only one file is allowed." description = "Only one file is allowed."

View File

@ -24,7 +24,7 @@ except ImportError:
) )
else: else:
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2) warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
magic = None # type: ignore magic = None # type: ignore[assignment]
from pydantic import BaseModel from pydantic import BaseModel

View File

@ -66,6 +66,7 @@ from .app import (
workflow_draft_variable, workflow_draft_variable,
workflow_run, workflow_run,
workflow_statistic, workflow_statistic,
workflow_trigger,
) )
# Import auth controllers # Import auth controllers
@ -126,6 +127,7 @@ from .workspace import (
models, models,
plugin, plugin,
tool_providers, tool_providers,
trigger_providers,
workspace, workspace,
) )
@ -196,6 +198,7 @@ __all__ = [
"statistic", "statistic",
"tags", "tags",
"tool_providers", "tool_providers",
"trigger_providers",
"version", "version",
"website", "website",
"workflow", "workflow",
@ -203,5 +206,6 @@ __all__ = [
"workflow_draft_variable", "workflow_draft_variable",
"workflow_run", "workflow_run",
"workflow_statistic", "workflow_statistic",
"workflow_trigger",
"workspace", "workspace",
] ]

View File

@ -15,6 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization") auth_token = extract_access_token(request)
if auth_header is None: if not auth_token:
raise Unauthorized("Authorization header is missing.") raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY: if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("desc", type=str, location="json") .add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("copyright", type=str, location="json") .add_argument("desc", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json") .add_argument("copyright", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json") .add_argument("privacy_policy", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") .add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json") .add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()

View File

@ -1,5 +1,4 @@
import flask_restx import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from sqlalchemy import select from sqlalchemy import select
@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import api, console_ns from . import api, console_ns
from .wraps import account_initialization_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = { api_key_fields = {
"id": fields.String, "id": fields.String,
@ -57,7 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@ -66,13 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys} return {"items": keys}
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _, current_tenant_id = current_account_with_tenant()
if not current_user.is_editor: _get_resource(resource_id, current_tenant_id, self.resource_model)
raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@ -89,7 +89,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24) key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -108,7 +108,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
@ -152,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -173,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -202,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"
@ -223,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"

View File

@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@console_ns.route("/app/prompt-templates") @console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates") @api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration") @api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect( @api.expect(parser)
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.response( @api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
) )
@ -25,11 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
parser.add_argument("model_name", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -8,17 +8,19 @@ from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from services.agent_service import AgentService from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
)
@console_ns.route("/apps/<uuid:app_id>/agent/logs") @console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource): class AgentLogApi(Resource):
@api.doc("get_agent_logs") @api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application") @api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))) @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters") @api.response(400, "Invalid request parameters")
@setup_required @setup_required
@ -27,10 +29,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

View File

@ -1,15 +1,14 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -17,6 +16,7 @@ from fields.annotation_fields import (
annotation_fields, annotation_fields,
annotation_hit_history_fields, annotation_hit_history_fields,
) )
from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@ -42,15 +42,15 @@ class AnnotationReplyActionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("score_threshold", required=True, type=float, location="json") reqparse.RequestParser()
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") .add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json") .add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
@ -69,10 +69,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200 return result, 200
@ -98,15 +96,12 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id): def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args() args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@ -124,10 +119,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
@ -159,10 +152,8 @@ class AnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str) keyword = request.args.get("keyword", default="", type=str)
@ -185,8 +176,10 @@ class AnnotationApi(Resource):
api.model( api.model(
"CreateAnnotationRequest", "CreateAnnotationRequest",
{ {
"question": fields.String(required=True, description="Question text"), "message_id": fields.String(description="Message ID (optional)"),
"answer": fields.String(required=True, description="Answer text"), "question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"), "annotation_reply": fields.Raw(description="Annotation reply data"),
}, },
) )
@ -198,25 +191,26 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("question", required=True, type=str, location="json") reqparse.RequestParser()
parser.add_argument("answer", required=True, type=str, location="json") .add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation return annotation
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id): def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly # Use request.args.getlist to get annotation_ids array directly
@ -249,16 +243,21 @@ class AnnotationExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)} response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200 return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation") @api.doc("update_delete_annotation")
@ -267,20 +266,16 @@ class AnnotationUpdateDeleteApi(Resource):
@api.response(200, "Annotation updated successfully", annotation_fields) @api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully") @api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation return annotation
@ -288,10 +283,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id): def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@ -310,10 +303,8 @@ class AnnotationBatchImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# check file # check file
if "file" not in request.files: if "file" not in request.files:
@ -341,10 +332,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id): def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
@ -376,10 +365,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id): def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id) app_id = str(app_id)

View File

@ -1,7 +1,5 @@
import uuid import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -12,14 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required, enterprise_license_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, App from libs.validators import validate_description_length
from models import App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -28,12 +28,6 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/apps") @console_ns.route("/apps")
class AppListApi(Resource): class AppListApi(Resource):
@api.doc("list_apps") @api.doc("list_apps")
@ -61,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
"""Get app list""" """Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value): def uuid_list(value):
try: try:
@ -68,34 +63,36 @@ class AppListApi(Resource):
except ValueError: except ValueError:
abort(400, message="Invalid UUID format in tag_ids.") abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument( .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
"mode", .add_argument(
type=str, "mode",
choices=[ type=str,
"completion", choices=[
"chat", "completion",
"advanced-chat", "chat",
"workflow", "advanced-chat",
"agent-chat", "workflow",
"channel", "agent-chat",
"all", "channel",
], "all",
default="all", ],
location="args", default="all",
required=False, location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
) )
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args() args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -134,30 +131,26 @@ class AppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, location="json") parser = (
parser.add_argument("description", type=_validate_description_length, location="json") reqparse.RequestParser()
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") .add_argument("name", type=str, required=True, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None: if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
if not isinstance(current_user, Account): app = app_service.create_app(current_tenant_id, args, current_user)
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201 return app, 201
@ -210,21 +203,20 @@ class AppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def put(self, app_model): def put(self, app_model):
"""Update app""" """Update app"""
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
parser = reqparse.RequestParser() .add_argument("icon_type", type=str, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") .add_argument("icon_background", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("max_active_requests", type=int, location="json")
parser.add_argument("icon_background", type=str, location="json") )
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -253,12 +245,9 @@ class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model): def delete(self, app_model):
"""Delete app""" """Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService() app_service = AppService()
app_service.delete_app(app_model) app_service.delete_app(app_model)
@ -288,28 +277,29 @@ class AppCopyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def post(self, app_model): def post(self, app_model):
"""Copy app""" """Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, location="json") reqparse.RequestParser()
parser.add_argument("description", type=_validate_description_length, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app( result = import_service.import_app(
account=account, account=current_user,
import_mode=ImportMode.YAML_CONTENT.value, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.get("name"), name=args.get("name"),
description=args.get("description"), description=args.get("description"),
@ -345,16 +335,15 @@ class AppExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = (
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") reqparse.RequestParser()
parser.add_argument("workflow_id", type=str, location="args") .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return { return {
@ -364,25 +353,23 @@ class AppExportApi(Resource):
} }
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name") @console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource): class AppNameApi(Resource):
@api.doc("check_app_name") @api.doc("check_app_name")
@api.doc(description="Check if app name is available") @api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check")) @api.expect(parser)
@api.response(200, "Name availability checked") @api.response(200, "Name availability checked")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -413,14 +400,13 @@ class AppIconApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
parser = reqparse.RequestParser() )
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -446,13 +432,9 @@ class AppSiteStatus(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -480,11 +462,11 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -525,13 +507,14 @@ class AppTraceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id): def post(self, app_id):
# add app trace # add app trace
if not current_user.is_editor: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("enabled", type=bool, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(

View File

@ -1,20 +1,17 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import App from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -22,36 +19,39 @@ from services.feature_service import FeatureService
from .. import console_ns from .. import console_ns
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
@console_ns.route("/apps/imports") @console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -70,9 +70,9 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED.value: if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value: elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@ -83,21 +83,21 @@ class AppImportConfirmApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id): def post(self, import_id):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
# Return appropriate status code based on result # Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@ -109,10 +109,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model @get_app_model
@account_initialization_required @account_initialization_required
@marshal_with(app_import_check_dependencies_fields) @marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model) result = import_service.check_dependencies(app_model=app_model)

View File

@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, location="json") reqparse.RequestParser()
parser.add_argument("text", type=str, location="json") .add_argument("message_id", type=str, location="json")
parser.add_argument("voice", type=str, location="json") .add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
@ -166,8 +168,7 @@ class TextModesApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(

View File

@ -2,7 +2,7 @@ import logging
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("model_config", type=dict, required=True, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=True, location="json")
raise Forbidden() .add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
parser = reqparse.RequestParser() .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("inputs", type=dict, required=True, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("query", type=str, required=True, location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("model_config", type=dict, required=True, location="json") )
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"

View File

@ -1,17 +1,14 @@
from datetime import datetime
import pytz # pip install pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import current_user from flask import abort
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import ( from fields.conversation_fields import (
@ -20,10 +17,10 @@ from fields.conversation_fields import (
conversation_pagination_fields, conversation_pagination_fields,
conversation_with_summary_pagination_fields, conversation_with_summary_pagination_fields,
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -57,18 +54,24 @@ class CompletionConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = sa.select(Conversation).where( query = sa.select(Conversation).where(
@ -84,25 +87,18 @@ class CompletionConversationApi(Resource):
) )
account = current_user account = current_user
timezone = pytz.timezone(account.timezone) assert account.timezone is not None
utc_timezone = pytz.utc
if args["start"]: try:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime = start_datetime.replace(second=0) except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc) query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime_utc = end_datetime_utc.replace(second=59)
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at < end_datetime_utc) query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file # FIXME, the type ignore in this file
@ -137,9 +133,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields) @marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -154,14 +149,12 @@ class CompletionConversationDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -206,26 +199,32 @@ class ChatConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
) type=str,
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") choices=["annotated", "not_annotated", "all"],
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") default="all",
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") location="args",
parser.add_argument( )
"sort_by", .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
type=str, .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
choices=["created_at", "-created_at", "updated_at", "-updated_at"], .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
required=False, .add_argument(
default="-updated_at", "sort_by",
location="args", type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -260,29 +259,22 @@ class ChatConversationApi(Resource):
) )
account = current_user account = current_user
timezone = pytz.timezone(account.timezone) assert account.timezone is not None
utc_timezone = pytz.utc
if args["start"]: try:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
start_datetime = start_datetime.replace(second=0) except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
match args["sort_by"]: match args["sort_by"]:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc) query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc) query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime_utc = end_datetime_utc.replace(second=59)
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]: match args["sort_by"]:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc) query = query.where(Conversation.updated_at <= end_datetime_utc)
@ -309,7 +301,7 @@ class ChatConversationApi(Resource):
) )
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]: match args["sort_by"]:
case "created_at": case "created_at":
@ -341,9 +333,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields) @marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -358,14 +349,12 @@ class ChatConversationDetailApi(Resource):
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -374,6 +363,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT) @get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields) @marshal_with(paginated_conversation_variable_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
stmt = ( stmt = (

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -12,12 +11,13 @@ from controllers.console.app.error import (
) )
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -43,16 +43,18 @@ class RuleGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=args["no_variable"], no_variable=args["no_variable"],
@ -93,17 +95,19 @@ class RuleCodeGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") .add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["code_language"], code_language=args["code_language"],
@ -140,15 +144,17 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
) )
@ -189,22 +195,23 @@ class InstructionGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("flow_id", type=str, required=True, default="", location="json") reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=False, default="", location="json") .add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json") .add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json") .add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") .add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() .add_argument("ideal_output", type=str, required=False, default="", location="json")
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
) )
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try: try:
# Generate from nothing for a workflow node # Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
@ -222,21 +229,21 @@ class InstructionGenerateApi(Resource):
match node_type: match node_type:
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["language"], code_language=args["language"],
@ -245,7 +252,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy( return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
current=args["current"], current=args["current"],
instruction=args["instruction"], instruction=args["instruction"],
@ -254,7 +261,7 @@ class InstructionGenerateApi(Resource):
) )
if args["node_id"] != "" and args["current"] != "": # For workflow node if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow( return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
node_id=args["node_id"], node_id=args["node_id"],
current=args["current"], current=args["current"],
@ -293,8 +300,7 @@ class InstructionGenerationTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()
match args["type"]: match args["type"]:
case "prompt": case "prompt":

View File

@ -1,16 +1,15 @@
import json import json
from enum import StrEnum from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_server_fields from fields.app_fields import app_server_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application") @api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model @get_app_model
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
) )
@api.response(201, "MCP server configuration created successfully", app_server_fields) @api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
description = args.get("description") description = args.get("description")
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False), parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE, status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id, app_id=app_model.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16), server_code=AppMCPServer.generate_server_code(16),
) )
db.session.add(server) db.session.add(server)
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields) @api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@api.response(404, "Server not found") @api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model): def put(self, app_model):
if not current_user.is_editor: parser = (
raise NotFound() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("id", type=str, required=True, location="json")
parser.add_argument("id", type=str, required=True, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("status", type=str, required=False, location="json")
parser.add_argument("status", type=str, required=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id): def get(self, server_id):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id) .where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id) .where(AppMCPServer.tenant_id == current_tenant_id)
.first() .first()
) )
if not server: if not server:

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -16,20 +16,18 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, edit_permission_required,
setup_required, setup_required,
) )
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService from services.message_service import MessageService
@ -56,19 +54,19 @@ class ChatMessageListApi(Resource):
) )
@api.response(200, "Success", message_infinite_scroll_pagination_fields) @api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser = reqparse.RequestParser() .add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("first_id", type=uuid_value, location="args") )
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
conversation = ( conversation = (
@ -154,12 +152,13 @@ class MessageFeedbackApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_model): def post(self, app_model):
if current_user is None: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", required=True, type=uuid_value, location="json") reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") .add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = str(args["message_id"]) message_id = str(args["message_id"])
@ -193,47 +192,6 @@ class MessageFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count") @console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource): class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count") @api.doc("get_annotation_count")
@ -270,6 +228,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id): def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id) message_id = str(message_id)
try: try:
@ -304,12 +263,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields) @api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found") @api.response(404, "Message not found")
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(message_detail_fields) @marshal_with(message_detail_fields)
def get(self, app_model, message_id): def get(self, app_model, message_id: str):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@ -2,7 +2,6 @@ import json
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -14,8 +13,8 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.datetime_utils import naive_utc_now
from models.account import Account from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
@ -53,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise Forbidden()
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
config=cast(dict, request.json), config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode), app_mode=AppMode.value_of(app_model.mode),
) )
@ -90,16 +87,16 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3: if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue continue
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
@ -124,7 +121,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input # encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
@ -133,7 +130,7 @@ class ModelConfigResource(Resource):
else: else:
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
@ -141,7 +138,7 @@ class ModelConfigResource(Resource):
continue continue
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
@ -172,6 +169,8 @@ class ModelConfigResource(Resource):
db.session.flush() db.session.flush()
app_model.app_model_config_id = new_app_model_config.id app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Site from models import Site
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = (
parser.add_argument("title", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("icon_type", type=str, required=False, location="json") .add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json") .add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json") .add_argument("icon", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json") .add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") .add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json") .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json") .add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json") .add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json") .add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument( .add_argument("custom_disclaimer", type=str, required=False, location="json")
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" .add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
) )
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args() return parser.parse_args()
@ -76,9 +81,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner # The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@ -107,8 +113,6 @@ class AppSite(Resource):
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -1,10 +1,7 @@
from datetime import datetime
from decimal import Decimal from decimal import Decimal
import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask import jsonify from flask import abort, jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -12,8 +9,9 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message from models import AppMode, Message
@ -37,11 +35,13 @@ class DailyMessageStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -52,28 +52,19 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -89,16 +80,19 @@ WHERE
return jsonify({"data": response_data}) return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics") @api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application") @api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Daily conversation statistics retrieved successfully", "Daily conversation statistics retrieved successfully",
@ -109,15 +103,15 @@ class DailyConversationStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
stmt = ( stmt = (
sa.select( sa.select(
@ -127,21 +121,13 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
) )
.select_from(Message) .select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
) )
if args["start"]: if start_datetime_utc:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at >= start_datetime_utc) stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
stmt = stmt.where(Message.created_at < end_datetime_utc) stmt = stmt.where(Message.created_at < end_datetime_utc)
stmt = stmt.group_by("date").order_by("date") stmt = stmt.group_by("date").order_by("date")
@ -160,11 +146,7 @@ class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics") @api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application") @api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Daily terminal statistics retrieved successfully", "Daily terminal statistics retrieved successfully",
@ -175,11 +157,8 @@ class DailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -190,28 +169,19 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -232,11 +202,7 @@ class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics") @api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application") @api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Daily token cost statistics retrieved successfully", "Daily token cost statistics retrieved successfully",
@ -247,11 +213,8 @@ class DailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -263,28 +226,19 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -307,11 +261,7 @@ class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics") @api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application") @api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Average session interaction statistics retrieved successfully", "Average session interaction statistics retrieved successfully",
@ -322,11 +272,8 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -345,28 +292,19 @@ FROM
WHERE WHERE
c.app_id = :app_id c.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND c.created_at >= :start" sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at < :end" sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -398,11 +336,7 @@ class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics") @api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application") @api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"User satisfaction rate statistics retrieved successfully", "User satisfaction rate statistics retrieved successfully",
@ -413,11 +347,8 @@ class UserSatisfactionRateStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -432,28 +363,19 @@ LEFT JOIN
WHERE WHERE
m.app_id = :app_id m.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND m.created_at >= :start" sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at < :end" sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -479,11 +401,7 @@ class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics") @api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application") @api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Average response time statistics retrieved successfully", "Average response time statistics retrieved successfully",
@ -494,11 +412,8 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -509,28 +424,19 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
@ -551,11 +457,7 @@ class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics") @api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application") @api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.expect( @api.expect(parser)
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response( @api.response(
200, 200,
"Tokens per second statistics retrieved successfully", "Tokens per second statistics retrieved successfully",
@ -566,11 +468,7 @@ class TokensPerSecondStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -584,28 +482,19 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
if args["start"]: abort(400, description=str(e))
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
sql_query += " AND created_at >= :start" sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if end_datetime_utc:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc

View File

@ -12,23 +12,33 @@ import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.debug.event_selectors import (
TriggerDebugEvent,
TriggerDebugEventPoller,
create_event_poller,
select_trigger_debug_events,
)
from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory, variable_factory from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from models.account import Account
from models.model import AppMode from models.model import AppMode
from models.workflow import Workflow from models.workflow import Workflow
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -37,6 +47,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@ -69,15 +80,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get draft workflow Get draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model) workflow = workflow_service.get_draft_workflow(app_model=app_model)
@ -106,27 +113,38 @@ class DraftWorkflowApi(Resource):
}, },
) )
) )
@api.response(200, "Draft workflow synced successfully", workflow_fields) @api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(400, "Invalid workflow configuration") @api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied") @api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, nullable=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -148,10 +166,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
else: else:
abort(415) abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -205,24 +219,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
parser = reqparse.RequestParser() .add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("inputs", type=dict, location="json") .add_argument("files", type=list, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("files", type=list, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") )
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -270,18 +281,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -322,18 +328,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -374,19 +375,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -427,19 +422,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -479,20 +468,17 @@ class DraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("files", type=list, required=False, location="json")
raise Forbidden() )
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
@ -525,17 +511,11 @@ class WorkflowTaskStopApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str): def post(self, app_model: App, task_id: str):
""" """
Stop workflow task Stop workflow task
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility # Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check) # Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id) AppQueueManager.set_stop_flag_no_user_check(task_id)
@ -567,21 +547,18 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow node Run draft workflow node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=False, location="json", default="")
raise Forbidden() .add_argument("files", type=list, location="json", default=[])
)
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args() args = parser.parse_args()
user_inputs = args.get("inputs") user_inputs = args.get("inputs")
@ -609,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish") @console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource): class PublishedWorkflowApi(Resource):
@api.doc("get_published_workflow") @api.doc("get_published_workflow")
@ -621,17 +605,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflow Get published workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model # fetch published workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model) workflow = workflow_service.get_published_workflow(app_model=app_model)
@ -639,24 +617,19 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None # return workflow, if not found, return None
return workflow return workflow
@api.expect(parser_publish)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Publish workflow Publish workflow
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser() args = parser_publish.parse_args()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
if args.marked_name and len(args.marked_name) > 20: if args.marked_name and len(args.marked_name) > 20:
@ -674,8 +647,12 @@ class PublishedWorkflowApi(Resource):
marked_comment=args.marked_comment or "", marked_comment=args.marked_comment or "",
) )
app_model.workflow_id = workflow.id # Update app_model within the same session to ensure atomicity
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at) workflow_created_at = TimestampField().format(workflow.created_at)
@ -697,22 +674,19 @@ class DefaultBlockConfigsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs # Get default block configs
workflow_service = WorkflowService() workflow_service = WorkflowService()
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource): class DefaultBlockConfigApi(Resource):
@api.doc("get_default_block_config") @api.doc("get_default_block_config")
@ -720,23 +694,17 @@ class DefaultBlockConfigApi(Resource):
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"}) @api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@api.response(200, "Default block configuration retrieved successfully") @api.response(200, "Default block configuration retrieved successfully")
@api.response(404, "Block type not found") @api.response(404, "Block type not found")
@api.expect(parser_block)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str): def get(self, app_model: App, block_type: str):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account): args = parser_block.parse_args()
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -752,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters) return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow") @console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource): class ConvertToWorkflowApi(Resource):
@api.expect(parser_convert)
@api.doc("convert_to_workflow") @api.doc("convert_to_workflow")
@api.doc(description="Convert application to workflow mode") @api.doc(description="Convert application to workflow mode")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -764,25 +742,17 @@ class ConvertToWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Convert basic mode of chatbot app to workflow mode Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App Convert Completion App to Workflow App
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() args = parser_convert.parse_args()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
else: else:
args = {} args = {}
@ -796,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
} }
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows") @console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource): class PublishedAllWorkflowApi(Resource):
@api.expect(parser_workflows)
@api.doc("get_all_published_workflows") @api.doc("get_all_published_workflows")
@api.doc(description="Get all published workflows for an application") @api.doc(description="Get all published workflows for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -807,24 +787,16 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields) @marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflows Get published workflows
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): args = parser_workflows.parse_args()
raise Forbidden() page = args["page"]
if not current_user.has_edit_permission: limit = args["limit"]
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
user_id = args.get("user_id") user_id = args.get("user_id")
named_only = args.get("named_only", False) named_only = args.get("named_only", False)
@ -874,19 +846,17 @@ class WorkflowByIdApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str): def patch(self, app_model: App, workflow_id: str):
""" """
Update workflow attributes Update workflow attributes
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
# Check permission reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("marked_name", type=str, required=False, location="json")
raise Forbidden() .add_argument("marked_comment", type=str, required=False, location="json")
)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -929,16 +899,11 @@ class WorkflowByIdApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str): def delete(self, app_model: App, workflow_id: str):
""" """
Delete workflow Delete workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction
@ -985,3 +950,234 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None: if node_exec is None:
raise NotFound("last run not found") raise NotFound("last run not found")
return node_exec return node_exec
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
class DraftWorkflowTriggerRunApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
"""
@api.doc("poll_draft_workflow_trigger_run")
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
@api.response(200, "Trigger event received and workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
poller: TriggerDebugEventPoller = create_event_poller(
draft_workflow=draft_workflow,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
event: TriggerDebugEvent | None = None
try:
event = poller.poll()
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
)
)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run")
class DraftWorkflowTriggerNodeApi(Resource):
"""
Single node debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
"""
@api.doc("poll_draft_workflow_trigger_node")
@api.doc(description="Poll for trigger events and execute single node when event arrives")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Trigger event received and node executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
event: TriggerDebugEvent | None = None
# for schedule trigger, when run single node, just execute directly
if node_type == NodeType.TRIGGER_SCHEDULE:
event = TriggerDebugEvent(
workflow_args={},
node_id=node_id,
)
# for other trigger types, poll for the event
else:
try:
poller: TriggerDebugEventPoller = create_event_poller(
draft_workflow=draft_workflow,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
event = poller.poll()
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
raw_files = event.workflow_args.get("files")
files = _parse_file(draft_workflow, raw_files if isinstance(raw_files, list) else None)
try:
node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=event.workflow_args.get("inputs") or {},
account=current_user,
query="",
files=files,
)
return jsonable_encoder(node_execution)
except Exception as e:
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{"status": "error", "error": "An unexpected error occurred while running the node."}
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
class DraftWorkflowTriggerRunAllApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
"""
@api.doc("draft_workflow_trigger_run_all")
@api.doc(description="Full workflow debug when the start node is a trigger")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@api.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Full workflow debug when the start node is a trigger
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
try:
trigger_debug_event: TriggerDebugEvent | None = select_trigger_debug_events(
draft_workflow=draft_workflow,
app_model=app_model,
user_id=current_user.id,
node_ids=node_ids,
)
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
if trigger_debug_event is None:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
try:
workflow_args = dict(trigger_debug_event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=trigger_debug_event.node_id,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger run-all")
return jsonable_encoder(
{
"status": "error",
}
), 400

View File

@ -28,6 +28,7 @@ class WorkflowAppLogApi(Resource):
"created_at__after": "Filter logs created after this timestamp", "created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID", "created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account", "created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)", "page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)", "limit": "Number of items per page (1-100)",
} }
@ -42,33 +43,36 @@ class WorkflowAppLogApi(Resource):
""" """
Get workflow app logs Get workflow app logs
""" """
parser = reqparse.RequestParser() parser = (
parser.add_argument("keyword", type=str, location="args") reqparse.RequestParser()
parser.add_argument( .add_argument("keyword", type=str, location="args")
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" .add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None args.status = WorkflowExecutionStatus(args.status) if args.status else None
@ -90,6 +94,7 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after, created_at_after=args.created_at__after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
detail=args.detail,
created_by_end_user_session_id=args.created_by_end_user_session_id, created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account, created_by_account=args.created_by_account,
) )

View File

@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App, AppMode from models import Account, App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -320,10 +321,11 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,5 @@
from typing import cast from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
@ -9,15 +8,85 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import ( from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields, advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields, workflow_run_detail_fields,
workflow_run_node_execution_list_fields, workflow_run_node_execution_list_fields,
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
from libs.custom_inputs import time_duration
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource): class AdvancedChatAppWorkflowRunListApi(Resource):
@ -25,6 +94,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc(description="Get advanced chat workflow run list") @api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -35,13 +106,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING if not specified
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result
@ -52,6 +174,8 @@ class WorkflowRunListApi(Resource):
@api.doc(description="Get workflow run list") @api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -62,13 +186,64 @@ class WorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING for workflow if not specified (backward compatibility)
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result

View File

@ -1,24 +1,26 @@
from datetime import datetime from flask import abort, jsonify
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_runs_statistic") @api.doc("get_workflow_daily_runs_statistic")
@api.doc(description="Get workflow daily runs statistics") @api.doc(description="Get workflow daily runs statistics")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -29,64 +31,41 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT assert account.timezone is not None
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]: response_data = self._workflow_run_repo.get_daily_runs_statistics(
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") tenant_id=app_model.tenant_id,
start_datetime = start_datetime.replace(second=0) app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_datetime_timezone = timezone.localize(start_datetime) start_date=start_date,
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) end_date=end_date,
timezone=account.timezone,
sql_query += " AND created_at >= :start" )
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
return jsonify({"data": response_data}) return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
class WorkflowDailyTerminalsStatistic(Resource): class WorkflowDailyTerminalsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_terminals_statistic") @api.doc("get_workflow_daily_terminals_statistic")
@api.doc(description="Get workflow daily terminals statistics") @api.doc(description="Get workflow daily terminals statistics")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -97,64 +76,41 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT assert account.timezone is not None
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]: response_data = self._workflow_run_repo.get_daily_terminals_statistics(
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") tenant_id=app_model.tenant_id,
start_datetime = start_datetime.replace(second=0) app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_datetime_timezone = timezone.localize(start_datetime) start_date=start_date,
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) end_date=end_date,
timezone=account.timezone,
sql_query += " AND created_at >= :start" )
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
return jsonify({"data": response_data}) return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
class WorkflowDailyTokenCostStatistic(Resource): class WorkflowDailyTokenCostStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_token_cost_statistic") @api.doc("get_workflow_daily_token_cost_statistic")
@api.doc(description="Get workflow daily token cost statistics") @api.doc(description="Get workflow daily token cost statistics")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -165,69 +121,41 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT assert account.timezone is not None
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]: response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") tenant_id=app_model.tenant_id,
start_datetime = start_datetime.replace(second=0) app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_datetime_timezone = timezone.localize(start_datetime) start_date=start_date,
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) end_date=end_date,
timezone=account.timezone,
sql_query += " AND created_at >= :start" )
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
return jsonify({"data": response_data}) return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
class WorkflowAverageAppInteractionStatistic(Resource): class WorkflowAverageAppInteractionStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_average_app_interaction_statistic") @api.doc("get_workflow_average_app_interaction_statistic")
@api.doc(description="Get workflow average app interaction statistics") @api.doc(description="Get workflow average app interaction statistics")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@ -238,74 +166,29 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT assert account.timezone is not None
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) try:
utc_timezone = pytz.utc start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if args["start"]: response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") tenant_id=app_model.tenant_id,
start_datetime = start_datetime.replace(second=0) app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_datetime_timezone = timezone.localize(start_datetime) start_date=start_date,
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) end_date=end_date,
timezone=account.timezone,
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") )
arg_dict["start"] = start_datetime_utc
else:
sql_query = sql_query.replace("{{start}}", "")
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
return jsonify({"data": response_data}) return jsonify({"data": response_data})

View File

@ -0,0 +1,145 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = str(args["node_id"])
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models import App, AppMode from models import App, AppMode
from models.account import Account
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None: def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]): def decorator(view_func: Callable[P1, R1]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser() active_check_parser = (
active_check_parser.add_argument( reqparse.RequestParser()
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
) .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
active_check_parser.add_argument( .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
) )
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = reqparse.RequestParser() active_parser = (
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") .add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument( .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
"interface_language", type=supported_language, required=True, nullable=False, location="json" .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
) )
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate") @console_ns.route("/activate")
@ -103,7 +99,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"] account.interface_language = args["interface_language"]
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) _, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings: if data_source_api_key_bindings:
return { return {
"sources": [ "sources": [
@ -41,16 +41,20 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("category", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args) ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -63,9 +67,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required @account_initialization_required
def delete(self, binding_id): def delete(self, binding_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -2,13 +2,12 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import api, console_ns
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import Account from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -1,12 +1,10 @@
from typing import cast
import flask_login import flask_login
from flask import request from flask import make_response, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
import services import services
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import get_valid_language
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AuthenticationFailedError, AuthenticationFailedError,
@ -26,7 +24,16 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from models.account import Account from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
@ -42,11 +49,13 @@ class LoginApi(Resource):
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("password", type=str, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") .add_argument("password", type=str, required=True, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json") .add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@ -89,19 +98,36 @@ class LoginApi(Resource):
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout") @console_ns.route("/logout")
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def post(self):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin): if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"} response = make_response({"result": "success"})
AccountService.logout(account=account) else:
flask_login.logout_user() AccountService.logout(account=account)
return {"result": "success"} flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password") @console_ns.route("/reset-password")
@ -109,9 +135,11 @@ class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans": if args["language"] is not None and args["language"] == "zh-Hans":
@ -137,9 +165,11 @@ class ResetPasswordSendEmailApi(Resource):
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -170,13 +200,17 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args["token"]) token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None: if token_data is None:
@ -210,7 +244,9 @@ class EmailCodeLoginApi(Resource):
if account is None: if account is None:
try: try:
account = AccountService.create_account_and_tenant( account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0] email=user_email,
name=user_email,
interface_language=get_valid_language(language),
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace() raise NotAllowedCreateWorkspace()
@ -220,18 +256,36 @@ class EmailCodeLoginApi(Resource):
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token") @console_ns.route("/refresh-token")
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() # Get refresh token from cookie instead of request body
parser.add_argument("refresh_token", type=str, required=True, location="json") refresh_token = extract_refresh_token(request)
args = parser.parse_args()
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try: try:
new_token_pair = AccountService.refresh_token(args["refresh_token"]) new_token_pair = AccountService.refresh_token(refresh_token)
return {"result": "success", "data": new_token_pair.model_dump()}
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e: except Exception as e:
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "message": str(e)}, 401

View File

@ -14,8 +14,12 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from libs.token import (
from models.account import AccountStatus set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -130,11 +134,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -153,9 +157,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
) )
return redirect( response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
) set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

View File

@ -1,16 +1,15 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request from flask import jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account from models import Account
from models.model import OAuthProviderApp from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -24,8 +23,7 @@ T = TypeVar("T")
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id") client_id = parsed_args.get("client_id")
if not client_id: if not client_id:
@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri") redirect_uri = parsed_args.get("redirect_uri")
@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required @account_initialization_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = (
parser.add_argument("grant_type", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=False, location="json") .add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json") .add_argument("code", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json") .add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json") .add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
try: try:

View File

@ -2,8 +2,8 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from enums.cloud_plan import CloudPlan
from models.model import Account from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService from services.billing_service import BillingService
@ -14,17 +14,21 @@ class Subscription(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser = (
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) reqparse.RequestParser()
args = parser.parse_args() .add_argument(
assert isinstance(current_user, Account) "plan",
type=str,
BillingService.is_tenant_owner_or_admin(current_user) required=True,
assert current_user.current_tenant_id is not None location="args",
return BillingService.get_subscription( choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id )
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
) )
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices") @console_ns.route("/billing/invoices")
@ -34,7 +38,6 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None return BillingService.get_invoices(current_user.email, current_tenant_id)
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -1,9 +1,8 @@
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService from services.billing_service import BillingService
from .. import console_ns from .. import console_ns
@ -17,17 +16,16 @@ class ComplianceApi(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("doc_name", type=str, required=True, location="args") parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device") device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link( return BillingService.get_compliance_download_link(
doc_name=args.doc_name, doc_name=args.doc_name,
account_id=current_user.id, account_id=current_user.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
ip=ip_address, ip=ip_address,
device_info=device_info, device_info=device_info,
) )

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -15,12 +14,12 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = db.session.scalars( data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where( select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
).all() ).all()
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_notion_info_list_fields) @marshal_with(integrate_notion_info_list_fields)
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars( documents = session.scalars(
select(Document).filter_by( select(Document).filter_by(
dataset_id=dataset_id, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
data_source_type="notion_import", data_source_type="notion_import",
enabled=True, enabled=True,
) )
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime( datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource", provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource", datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
) )
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, workspace_id, page_id, page_type): def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"), notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
) )
text_docs = extractor.extract() text_docs = extractor.extract()
@ -239,12 +244,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() _, current_tenant_id = current_account_with_tenant()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") parser = (
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
@ -256,20 +263,22 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],

View File

@ -1,6 +1,6 @@
import flask_restx from typing import Any, cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -23,29 +23,97 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description): def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
if description and len(description) > 400: """
raise ValueError("Description cannot exceed 400 characters.") Get supported retrieval methods based on vector database type.
return description
Args:
vector_type: Vector database type, can be None
is_mock: Whether this is a Mock API, affects MILVUS handling
Returns:
Dictionary containing supported retrieval methods
Raises:
ValueError: If vector_type is None or unsupported
"""
if vector_type is None:
raise ValueError("Vector store type is not configured.")
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
# Define vector database types that support all retrieval methods
full_search_types = {
VectorType.QDRANT,
VectorType.WEAVIATE,
VectorType.OPENSEARCH,
VectorType.ANALYTICDB,
VectorType.MYSCALE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.ELASTICSEARCH_JA,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.TIDB_ON_QDRANT,
VectorType.LINDORM,
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
VectorType.MATRIXONE,
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
full_methods = {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods
if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
return full_methods
else:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets") @console_ns.route("/datasets")
@ -68,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids") ids = request.args.getlist("ids")
@ -76,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true" include_all = request.args.get("include_all", default="false").lower() == "true"
if ids: if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else: else:
datasets, total = DatasetService.get_datasets( datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all page, limit, current_tenant_id, current_user, search, tag_ids, include_all
) )
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -92,7 +161,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields) data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
for item in data: for item in data:
# convert embedding_model_provider to plugin standard format # convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@ -137,50 +206,53 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=_validate_description_length, "description",
nullable=True, type=validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"external_knowledge_api_id", .add_argument(
type=str, "external_knowledge_api_id",
nullable=True, type=str,
required=False, nullable=True,
) required=False,
parser.add_argument( )
"provider", .add_argument(
type=str, "provider",
nullable=True, type=str,
choices=Dataset.PROVIDER_LIST, nullable=True,
required=False, choices=Dataset.PROVIDER_LIST,
default="vendor", required=False,
) default="vendor",
parser.add_argument( )
"external_knowledge_id", .add_argument(
type=str, "external_knowledge_id",
nullable=True, type=str,
required=False, nullable=True,
required=False,
)
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -188,7 +260,7 @@ class DatasetListApi(Resource):
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=args["name"],
description=args["description"], description=args["description"],
indexing_technique=args["indexing_technique"], indexing_technique=args["indexing_technique"],
@ -216,6 +288,7 @@ class DatasetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -224,7 +297,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider: if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider) provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -235,7 +308,7 @@ class DatasetApi(Resource):
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -281,73 +354,76 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
help="type is required. Name must be between 1 to 40 characters.", nullable=False,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) )
parser.add_argument( .add_argument("description", location="json", store_missing=False, type=validate_description_length)
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"permission", .add_argument(
type=str, "permission",
location="json", type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), location="json",
help="Invalid permission.", choices=(
) DatasetPermissionEnum.ONLY_ME,
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") DatasetPermissionEnum.ALL_TEAM,
parser.add_argument( DatasetPermissionEnum.PARTIAL_TEAM,
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." ),
) help="Invalid permission.",
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") )
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
parser.add_argument( "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
"external_retrieval_model", )
type=dict, .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
required=False, .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
nullable=True, .add_argument(
location="json", "external_retrieval_model",
help="Invalid external retrieval model.", type=dict,
) required=False,
nullable=True,
parser.add_argument( location="json",
"external_knowledge_id", help="Invalid external retrieval model.",
type=str, )
required=False, .add_argument(
nullable=True, "external_knowledge_id",
location="json", type=str,
help="Invalid external knowledge id.", required=False,
) nullable=True,
location="json",
parser.add_argument( help="Invalid external knowledge id.",
"external_knowledge_api_id", )
type=str, .add_argument(
required=False, "external_knowledge_api_id",
nullable=True, type=str,
location="json", required=False,
help="Invalid external knowledge api id.", nullable=True,
) location="json",
help="Invalid external knowledge api id.",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
required=False, type=dict,
nullable=True, required=False,
location="json", nullable=True,
help="Invalid icon info.", location="json",
help="Invalid icon info.",
)
) )
args = parser.parse_args() args = parser.parse_args()
data = request.get_json() data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting # check embedding model setting
if ( if (
@ -369,8 +445,8 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
@ -394,9 +470,9 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id): def delete(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
try: try:
@ -435,6 +511,7 @@ class DatasetQueryApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -469,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument( .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
"indexing_technique", .add_argument(
type=str, "indexing_technique",
required=True, type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
location="json", nullable=True,
) location="json",
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") )
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument( .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
extract_settings = [] extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file": if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars( file_details = db.session.scalars(
select(UploadFile).where( select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
).all() ).all()
if file_details is None: if file_details is None:
@ -503,7 +579,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details: if file_details:
for file_detail in file_details: for file_detail in file_details:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, datasource_type=DatasourceType.FILE,
upload_file=file_detail, upload_file=file_detail,
document_model=args["doc_form"], document_model=args["doc_form"],
) )
@ -515,14 +591,16 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -530,15 +608,17 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"] website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]: for url in website_info_list["urls"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": website_info_list["provider"], {
"job_id": website_info_list["job_id"], "provider": website_info_list["provider"],
"url": url, "job_id": website_info_list["job_id"],
"tenant_id": current_user.current_tenant_id, "url": url,
"mode": "crawl", "tenant_id": current_tenant_id,
"only_main_content": website_info_list["only_main_content"], "mode": "crawl",
}, "only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -547,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],
@ -578,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(related_app_list) @marshal_with(related_app_list)
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -609,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = db.session.scalars( documents = db.session.scalars(
select(Document).where( select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
).all() ).all()
documents_status = [] documents_status = []
for document in documents: for document in documents:
@ -665,10 +745,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
).all() ).all()
return {"items": keys} return {"items": keys}
@ -678,17 +757,18 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count() .count()
) )
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restx.abort( api.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code="max_keys_exceeded",
@ -696,7 +776,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -716,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, api_key_id): def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
@ -725,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .where(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
) )
@ -733,7 +814,7 @@ class DatasetApiDeleteApi(Resource):
) )
if key is None: if key is None:
flask_restx.abort(404, message="API key not found") api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
@ -776,49 +857,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
match vector_type: return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
case (
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>") @console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -831,48 +870,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, vector_type): def get(self, vector_type):
match vector_type: return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs") @console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@ -907,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

View File

@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -44,7 +43,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields, dataset_and_document_fields,
@ -53,7 +52,7 @@ from fields.document_fields import (
document_with_segments_fields, document_with_segments_fields,
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
@ -64,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -78,12 +78,13 @@ class DocumentResource(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
return document return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -111,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args req_data = request.args
document_id = req_data.get("document_id") document_id = req_data.get("document_id")
@ -167,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
@ -198,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -272,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -288,23 +292,23 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" .add_argument(
) "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
parser.add_argument("data_source", type=dict, required=False, location="json") )
parser.add_argument("process_rule", type=dict, required=False, location="json") .add_argument("data_source", type=dict, required=False, location="json")
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") .add_argument("process_rule", type=dict, required=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") .add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -371,37 +375,38 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", .add_argument(
type=str, "indexing_technique",
choices=Dataset.INDEXING_TECHNIQUE_LIST, type=str,
required=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=False, required=True,
location="json", nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
) )
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=args["embedding_model_provider"], provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"], model=args["embedding_model"],
@ -418,7 +423,9 @@ class DatasetInitApi(Resource):
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user tenant_id=current_tenant_id,
knowledge_config=knowledge_config,
account=current_user,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -444,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -452,7 +460,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
@ -472,14 +480,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
) )
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
estimate_response = indexing_runner.indexing_estimate( estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
[extract_setting], [extract_setting],
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -508,13 +516,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, batch): def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
batch = str(batch) batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
if not documents: if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = [] extract_settings = []
for document in documents: for document in documents:
if document.indexing_status in {"completed", "error"}: if document.indexing_status in {"completed", "error"}:
@ -527,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -535,7 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -543,14 +552,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info: if not data_source_info:
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": data_source_info["credential_id"], {
"notion_workspace_id": data_source_info["notion_workspace_id"], "credential_id": data_source_info["credential_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_page_type": data_source_info["type"], "notion_obj_id": data_source_info["notion_page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": data_source_info["type"],
}, "tenant_id": current_tenant_id,
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -558,15 +569,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info: if not data_source_info:
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": data_source_info["provider"], {
"job_id": data_source_info["job_id"], "provider": data_source_info["provider"],
"url": data_source_info["url"], "job_id": data_source_info["job_id"],
"tenant_id": current_user.current_tenant_id, "url": data_source_info["url"],
"mode": data_source_info["mode"], "tenant_id": current_tenant_id,
"only_main_content": data_source_info["only_main_content"], "mode": data_source_info["mode"],
}, "only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -576,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -733,7 +746,7 @@ class DocumentApi(DocumentResource):
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
"created_at": document.created_at.timestamp(), "created_at": int(document.created_at.timestamp()),
"tokens": document.tokens, "tokens": document.tokens,
"indexing_status": document.indexing_status, "indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -753,7 +766,7 @@ class DocumentApi(DocumentResource):
} }
else: else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -766,7 +779,7 @@ class DocumentApi(DocumentResource):
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
"created_at": document.created_at.timestamp(), "created_at": int(document.created_at.timestamp()),
"tokens": document.tokens, "tokens": document.tokens,
"indexing_status": document.indexing_status, "indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -827,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -877,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, dataset_id, document_id): def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -924,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
@ -1027,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") "document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -1070,12 +1087,14 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields) @marshal_with(document_fields)
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset) DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -1093,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
"""sync website document.""" """sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -1101,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
if document.data_source_type != "website_crawl": if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.") raise ValueError("Document is not a website document.")

View File

@ -1,7 +1,6 @@
import uuid import uuid
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -59,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("status", type=str, action="append", default=[], location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args") .add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("enabled", type=str, default="all", location="args") .add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -79,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment) select(DocumentSegment)
.where( .where(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_tenant_id,
) )
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
) )
@ -115,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id): def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -148,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -171,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -204,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -221,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -237,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset) segment = SegmentService.create_segment(args, document, dataset)
@ -255,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -272,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -287,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -300,16 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("keywords", type=list, required=False, nullable=True, location="json")
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" .add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
) )
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@ -317,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id): def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -333,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -361,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -372,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json") "upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
upload_file_id = args["upload_file_id"] upload_file_id = args["upload_file_id"]
@ -392,7 +412,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
# send batch add segments task # send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting") redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay( batch_create_segment_to_index_task.delay(
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id str(job_id),
upload_file_id,
dataset_id,
document_id,
current_tenant_id,
current_user.id,
) )
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {"error": str(e)}, 500
@ -422,6 +447,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -436,7 +463,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -448,7 +475,7 @@ class ChildChunkAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -464,11 +491,13 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -477,6 +506,8 @@ class ChildChunkAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id, segment_id): def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -493,15 +524,17 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -524,6 +557,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -540,7 +575,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -553,11 +588,13 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") "chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
@ -573,6 +610,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -589,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -600,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )
@ -627,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -643,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -654,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )
@ -670,13 +711,13 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.update_child_chunk( content = args["content"]
args.get("content"), child_chunk, segment, document, dataset child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,5 +1,4 @@
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -8,14 +7,14 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100: if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.") raise ValueError("Name must be between 1 to 100 characters.")
return name return name
@ -37,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search page, limit, current_tenant_id, search
) )
response = { response = {
"data": [item.to_dict() for item in external_knowledge_apis], "data": [item.to_dict() for item in external_knowledge_apis],
@ -57,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument( parser = (
"name", reqparse.RequestParser()
nullable=False, .add_argument(
required=True, "name",
help="Name is required. Name must be between 1 to 100 characters.", nullable=False,
type=_validate_name, required=True,
) help="Name is required. Name must be between 1 to 100 characters.",
parser.add_argument( type=_validate_name,
"settings", )
type=dict, .add_argument(
location="json", "settings",
nullable=False, type=dict,
required=True, location="json",
nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -82,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
try: try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args tenant_id=current_tenant_id, user_id=current_user.id, args=args
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -112,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, external_knowledge_api_id): def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 100 characters.",
) type=_validate_name,
parser.add_argument( )
"settings", .add_argument(
type=dict, "settings",
location="json", type=dict,
nullable=False, location="json",
required=True, nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"]) ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
args=args, args=args,
@ -145,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, external_knowledge_api_id): def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -196,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument( .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="name is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
) )
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -220,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
try: try:
dataset = ExternalDatasetService.create_external_dataset( dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
args=args, args=args,
) )
@ -252,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -262,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument("query", type=str, location="json") reqparse.RequestParser()
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@ -301,15 +313,17 @@ class BedrockRetrievalApi(Resource):
) )
@api.response(200, "Bedrock retrieval test completed") @api.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
"query", .add_argument(
nullable=False, "query",
required=True, nullable=False,
type=str, required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
) )
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args() args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service

View File

@ -1,10 +1,9 @@
import logging import logging
from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service import services
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -20,6 +19,8 @@ from core.errors.error import (
) )
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
@ -29,6 +30,7 @@ logger = logging.getLogger(__name__)
class DatasetsHitTestingBase: class DatasetsHitTestingBase:
@staticmethod @staticmethod
def get_and_validate_dataset(dataset_id: str): def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -46,15 +48,17 @@ class DatasetsHitTestingBase:
@staticmethod @staticmethod
def parse_args(): def parse_args():
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("query", type=str, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args() return parser.parse_args()
@staticmethod @staticmethod
def perform_hit_testing(dataset, args): def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try: try:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,

View File

@ -1,13 +1,12 @@
from typing import Literal from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs, MetadataArgs,
@ -24,11 +23,14 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def post(self, dataset_id): def post(self, dataset_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser = (
parser.add_argument("name", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataArgs(**args) metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -59,9 +61,10 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -70,7 +73,7 @@ class DatasetMetadataApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200 return metadata, 200
@setup_required @setup_required
@ -78,6 +81,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def delete(self, dataset_id, metadata_id): def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -107,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]): def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -127,16 +132,18 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") "operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataOperationData(**args) metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,19 +1,15 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen from libs.helper import StrLen
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, provider_id: str): def get(self, provider_id: str):
user = current_user current_user, current_tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
if not current_user.is_editor: tenant_id = current_tenant_id
raise Forbidden()
credential_id = request.args.get("credential_id") credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=current_user.id,
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_name, provider=provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
@ -125,27 +121,30 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@api.expect(parser_datasource)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() args = parser_datasource.parse_args()
parser.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
try: try:
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=args["credentials"],
name=args["name"], name=args["name"],
@ -160,31 +159,39 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str): def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials( datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
) )
return {"result": datasources}, 200 return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@api.expect(parser_datasource_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden() args = parser_datasource_delete.parse_args()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
@ -192,23 +199,30 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@api.expect(parser_datasource_update)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() args = parser_datasource_update.parse_args()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
@ -224,10 +238,10 @@ class DatasourceAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials( datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@ -237,29 +251,35 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials( datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@api.expect(parser_datasource_custom)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() args = parser_datasource_custom.parse_args()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False), enabled=args.get("enable_oauth_custom_client", False),
@ -270,52 +290,63 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider_id: str): def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params( datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@api.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() args = parser_default.parse_args()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=args["id"],
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@api.expect(parser_update_name)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() args = parser_update_name.parse_args()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=args["name"],
credential_id=args["credential_id"], credential_id=args["credential_id"],

View File

@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
) )
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -12,9 +12,17 @@ from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -26,10 +34,6 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")

View File

@ -20,13 +20,13 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description): def _validate_description_length(description: str) -> str:
if len(description) > 400: if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.") raise ValueError("Description cannot exceed 400 characters.")
return description return description
@ -66,29 +66,31 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=str, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=str, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()

View File

@ -1,5 +1,4 @@
from flask_login import current_user # type: ignore # type: ignore from flask_restx import Resource, marshal, reqparse
from flask_restx import Resource, marshal, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,25 +12,13 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/dataset") @console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@setup_required @setup_required
@ -39,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"yaml_content", "yaml_content",
type=str, type=str,
nullable=False, nullable=False,
@ -50,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
@ -70,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
) )
if rag_pipeline_dataset_create_entity.permission == "partial_members": if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id, current_tenant_id,
import_info["dataset_id"], import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list, rag_pipeline_dataset_create_entity.partial_member_list,
) )
@ -93,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset( dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="", name="",
description="", description="",

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Any, NoReturn from typing import NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@ -11,21 +11,19 @@ from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
) )
from controllers.console.app.workflow_draft_variable import ( from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
) )
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -34,43 +32,19 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -104,7 +78,7 @@ def _api_prerequisite(f):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.is_editor: if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)
@ -234,10 +208,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -28,26 +24,29 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("mode", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("yaml_content", type=str, location="json") .add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_url", type=str, location="json") .add_argument("yaml_content", type=str, location="json")
parser.add_argument("name", type=str, location="json") .add_argument("yaml_url", type=str, location="json")
parser.add_argument("description", type=str, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json") .add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_rag_pipeline( result = import_service.import_rag_pipeline(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -60,9 +59,9 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED.value: if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value: elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@ -74,20 +73,21 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self, import_id): def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first # Check user role first
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
# Return appropriate status code based on result # Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@ -100,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields) @marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -117,12 +118,12 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
parser.add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
DraftWorkflowNotExist, DraftWorkflowNotExist,
@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
edit_permission_required,
setup_required, setup_required,
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
) )
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.model import EndUser from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get draft rag pipeline's workflow Get draft rag pipeline's workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json") .add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -148,23 +148,25 @@ class DraftRagPipelineApi(Resource):
} }
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@api.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() args = parser_run.parse_args()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = PipelineGenerateService.generate_single_iteration( response = PipelineGenerateService.generate_single_iteration(
@ -185,6 +187,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@api.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -194,12 +197,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_run.parse_args()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = PipelineGenerateService.generate_single_loop( response = PipelineGenerateService.generate_single_loop(
@ -218,8 +220,18 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@api.expect(parser_draft_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -229,15 +241,11 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_draft_run.parse_args()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
args = parser.parse_args()
try: try:
response = PipelineGenerateService.generate( response = PipelineGenerateService.generate(
@ -253,8 +261,21 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@api.expect(parser_published_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -264,18 +285,11 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow Run published workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_published_run.parse_args()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
parser.add_argument("original_document_id", type=str, required=False, location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -303,15 +317,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -344,15 +359,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -374,8 +390,17 @@ class PublishedRagPipelineRunApi(Resource):
# #
# return result # return result
# #
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -385,14 +410,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_rag_run.parse_args()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
if inputs is None: if inputs is None:
@ -419,6 +441,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -428,14 +451,11 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_rag_run.parse_args()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
if inputs is None: if inputs is None:
@ -460,8 +480,14 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
) )
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@api.expect(parser_run_api)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -472,12 +498,11 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node Run draft workflow node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_run_api.parse_args()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
if inputs == None: if inputs == None:
@ -505,7 +530,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task Stop workflow task
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -525,7 +551,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline Get published pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
if not pipeline.is_published: if not pipeline.is_published:
return None return None
@ -545,7 +572,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow Publish workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -580,7 +608,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Get default block configs # Get default block configs
@ -588,8 +617,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@api.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -599,12 +632,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_default.parse_args()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -620,8 +652,18 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@api.expect(parser_wf)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -631,17 +673,13 @@ class PublishedAllRagPipelineApi(Resource):
""" """
Get published workflows Get published workflows
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_wf.parse_args()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") page = args["page"]
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") limit = args["limit"]
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
user_id = args.get("user_id") user_id = args.get("user_id")
named_only = args.get("named_only", False) named_only = args.get("named_only", False)
@ -669,8 +707,16 @@ class PublishedAllRagPipelineApi(Resource):
} }
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@api.expect(parser_wf_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -681,20 +727,17 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes Update workflow attributes
""" """
# Check permission # Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() args = parser_wf_id.parse_args()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
if args.marked_name and len(args.marked_name) > 20: if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters") raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100: if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters") raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data # Prepare update data
update_data = {} update_data = {}
@ -727,22 +770,22 @@ class RagPipelineByIdApi(Resource):
return workflow return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor args = parser_parameters.parse_args()
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
@ -755,20 +798,17 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor args = parser_parameters.parse_args()
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
@ -781,20 +821,17 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor args = parser_parameters.parse_args()
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
@ -807,20 +844,17 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor args = parser_parameters.parse_args()
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
raise ValueError("Node ID is required") raise ValueError("Node ID is required")
@ -832,8 +866,16 @@ class DraftRagPipelineSecondStepApi(Resource):
} }
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@api.expect(parser_wf_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -843,10 +885,7 @@ class RagPipelineWorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() args = parser_wf_run.parse_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -880,7 +919,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields) @marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id): def get(self, pipeline: Pipeline, run_id: str):
""" """
Get workflow run node execution list Get workflow run node execution list
""" """
@ -903,14 +942,8 @@ class DatasourceListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, current_tenant_id = current_account_with_tenant()
if not isinstance(user, Account): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@ -940,9 +973,8 @@ class RagPipelineTransformApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id: str):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
if not (current_user.has_edit_permission or current_user.is_dataset_operator): if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
@ -953,26 +985,30 @@ class RagPipelineTransformApi(Resource):
return result return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@api.expect(parser_var)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Set datasource variables Set datasource variables
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden() args = parser_var.parse_args()
parser = reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("start_node_title", type=str, required=True, location="json")
args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables( workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"provider", .add_argument(
type=str, "provider",
choices=["firecrawl", "watercrawl", "jinareader"], type=str,
required=True, choices=["firecrawl", "watercrawl", "jinareader"],
nullable=True, required=True,
location="json", nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
) )
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"): if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters") raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user is not an account")
pipeline_id = kwargs.get("pipeline_id") pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id) pipeline_id = str(pipeline_id)
@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = ( pipeline = (
db.session.query(Pipeline) db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first() .first()
) )

View File

@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("voice", type=str, location="json") .add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("text", type=str, location="json") .add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)

View File

@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
args["auto_generate_name"] = False args["auto_generate_name"] = False

View File

@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args() args = parser.parse_args()
pinned = None pinned = None
@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") .add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -6,31 +6,29 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
app_id = request.args.get("app_id", default=None, type=str) app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
if app_id: if app_id:
installed_apps = db.session.scalars( installed_apps = db.session.scalars(
@ -68,31 +66,26 @@ class InstalledAppsListApi(Resource):
# Pre-filter out apps without setting or with sso_verified # Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = [] filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list: for installed_app in installed_app_list:
app_id = installed_app["app"].id app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id) webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified": if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue continue
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app) filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check # Batch permission check
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id, user_id=user_id,
app_codes=app_codes, app_ids=app_ids,
) )
# Keep only allowed apps # Keep only allowed apps
res = [] res = []
for installed_app in filtered_installed_apps: for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id] if permissions.get(app_id):
if permissions.get(app_code):
res.append(installed_app) res.append(installed_app)
installed_app_list = res installed_app_list = res
@ -112,17 +105,15 @@ class InstalledAppsListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None: if app is None:
@ -154,6 +145,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"} return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource): class InstalledAppApi(InstalledAppResource):
""" """
update and delete an installed app update and delete an installed app
@ -161,9 +153,8 @@ class InstalledAppApi(InstalledAppResource):
""" """
def delete(self, installed_app): def delete(self, installed_app):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_tenant_id:
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app) db.session.delete(installed_app)
@ -172,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204 return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app): def patch(self, installed_app):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
parser.add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args() args = parser.parse_args()
commit_args = False commit_args = False
@ -185,7 +175,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit() db.session.commit()
return {"result": "success", "message": "App info updated successfully"} return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -48,21 +47,22 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("first_id", type=uuid_value, location="args") .add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
) )
@ -78,18 +78,19 @@ class MessageListApi(InstalledAppResource):
) )
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") reqparse.RequestParser()
parser.add_argument("content", type=str, location="json") .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
@ -109,14 +110,14 @@ class MessageFeedbackApi(InstalledAppResource):
) )
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
) )
args = parser.parse_args() args = parser.parse_args()
@ -124,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
@ -159,6 +158,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
) )
class MessageSuggestedQuestionApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -167,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
) )

View File

@ -1,7 +1,7 @@
from flask_restx import marshal_with from flask_restx import marshal_with
from controllers.common import fields from controllers.common import fields
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource): class AppParameterApi(InstalledAppResource):
"""Resource for app variables.""" """Resource for app variables."""
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model: if not app_model:
raise ValueError("App not found") raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -35,15 +35,18 @@ recommended_app_list_fields = {
} }
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@api.expect(parser_apps)
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
# language args # language args
parser = reqparse.RequestParser() args = parser_apps.parse_args()
parser.add_argument("language", type=str, location="args")
args = parser.parse_args()
language = args.get("language") language = args.get("language")
if language and language in languages: if language and language in languages:
@ -56,13 +59,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource): class RecommendedAppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id) return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@ -2,13 +2,12 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
@ -25,6 +24,7 @@ message_fields = {
} }
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource): class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = { saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer, "limit": fields.Integer,
@ -34,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app): def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -66,8 +65,12 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id): def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
@ -75,20 +78,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper from libs import helper
from libs.login import current_user from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -38,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
""" """
Run workflow Run workflow
""" """
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if not app_model: if not app_model:
raise NotWorkflowAppError() raise NotWorkflowAppError()
@ -45,11 +46,12 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("files", type=list, required=False, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
assert current_user is not None
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -85,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility # Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check) # Legacy stop flag mechanism (without user check)

View File

@ -2,16 +2,14 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import InstalledApp from models import InstalledApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -24,11 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first() .first()
) )
@ -54,13 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
app_id = installed_app.app_id app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id), user_id=str(current_user.id),
app_code=app_code, app_id=app_id,
) )
if not res: if not res:
raise AppAccessDeniedError() raise AppAccessDeniedError()

Some files were not shown because too many files have changed in this diff Show More