Merge branch 'main' into feat/memory-orchestration-fed

This commit is contained in:
zxhlyh 2025-10-11 14:08:09 +08:00
commit 0e545d7be5
717 changed files with 13706 additions and 7996 deletions

View File

@ -1,4 +1,4 @@
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev

View File

@ -1,5 +1,8 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: "\U0001F510 Security Vulnerabilities"
url: "https://github.com/langgenius/dify/security/advisories/new"
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
- name: "\U0001F4A1 Model Providers & Plugins" - name: "\U0001F4A1 Model Providers & Plugins"
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details. about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.

View File

@ -15,10 +15,12 @@ jobs:
# Use uv to ensure we have the same ruff version in CI and locally. # Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6 - uses: astral-sh/setup-uv@v6
with: with:
python-version: "3.12" python-version: "3.11"
- run: | - run: |
cd api cd api
uv sync --dev uv sync --dev
# fmt first to avoid line too long
uv run ruff format ..
# Fix lint errors # Fix lint errors
uv run ruff check --fix . uv run ruff check --fix .
# Format code # Format code
@ -28,6 +30,8 @@ jobs:
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types) # Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF' cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union id: convert-optional-to-union

View File

@ -4,10 +4,10 @@ on:
push: push:
branches: branches:
- "main" - "main"
- "deploy/dev" - "deploy/**"
- "deploy/enterprise"
- "build/**" - "build/**"
- "release/e-*" - "release/e-*"
- "hotfix/**"
tags: tags:
- "*" - "*"

View File

@ -18,7 +18,7 @@ jobs:
- name: Deploy to server - name: Deploy to server
uses: appleboy/ssh-action@v0.1.8 uses: appleboy/ssh-action@v0.1.8
with: with:
host: ${{ secrets.RAG_SSH_HOST }} host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }} username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }} key: ${{ secrets.SSH_PRIVATE_KEY }}
script: | script: |

View File

@ -1,4 +1,4 @@
name: Deploy RAG Dev name: Deploy Trigger Dev
permissions: permissions:
contents: read contents: read
@ -7,7 +7,7 @@ on:
workflow_run: workflow_run:
workflows: ["Build and Push API & Web"] workflows: ["Build and Push API & Web"]
branches: branches:
- "deploy/rag-dev" - "deploy/trigger-dev"
types: types:
- completed - completed
@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: | if: |
github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev' github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps: steps:
- name: Deploy to server - name: Deploy to server
uses: appleboy/ssh-action@v0.1.8 uses: appleboy/ssh-action@v0.1.8
with: with:
host: ${{ secrets.RAG_SSH_HOST }} host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }} username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }} key: ${{ secrets.SSH_PRIVATE_KEY }}
script: | script: |

View File

@ -4,84 +4,51 @@
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
The codebase consists of: The codebase is split into:
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture - **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 - **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations - **Docker deployment** (`/docker`): Containerized deployment configurations
## Development Commands ## Backend Workflow
### Backend (API) - Run backend CLI commands through `uv run --project api <command>`.
All Python commands must be prefixed with `uv run --project api`: - Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
```bash - Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
# Start development servers
./dev/start-api # Start API server
./dev/start-worker # Start Celery worker
# Run tests - Integration tests are CI-only and are not expected to run in the local environment.
uv run --project api pytest # Run all tests
uv run --project api pytest tests/unit_tests/ # Unit tests only
uv run --project api pytest tests/integration_tests/ # Integration tests
# Code quality ## Frontend Workflow
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)
```bash ```bash
cd web cd web
pnpm lint # Run ESLint pnpm lint
pnpm eslint-fix # Fix ESLint issues pnpm lint:fix
pnpm test # Run Jest tests pnpm test
``` ```
## Testing Guidelines ## Testing & Quality Practices
### Backend Testing - Follow TDD: red → green → refactor.
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
- Write self-documenting code; only add comments that explain intent.
- Use `pytest` for all backend tests ## Language Style
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert
## Code Style Requirements - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
### Python ## General Practices
- Use type hints for all functions and class attributes - Prefer editing existing files; add new documentation only when requested.
- No `Any` types unless absolutely necessary - Inject dependencies through constructors and preserve clean architecture boundaries.
- Implement special methods (`__repr__`, `__str__`) appropriately - Handle errors with domain-specific exceptions at the correct layer.
### TypeScript/JavaScript ## Project Conventions
- Strict TypeScript configuration - Backend architecture adheres to DDD and Clean Architecture principles.
- ESLint with Prettier integration - Async work runs through Celery with Redis as the broker.
- Avoid `any` type - Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.
## Important Notes
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
## Common Development Tasks
### Adding a New API Endpoint
1. Create controller in `/api/controllers/`
1. Add service logic in `/api/services/`
1. Update routes in controller's `__init__.py`
1. Write tests in `/api/tests/`
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@ -26,7 +26,6 @@ prepare-web:
@echo "🌐 Setting up web environment..." @echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists" @cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install @cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)" @echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment # Step 3: Prepare API environment

View File

@ -40,18 +40,18 @@
<p align="center"> <p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a> <a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a> <a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a> <a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a> <a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a> <a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a> <a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a> <a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a> <a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a> <a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a> <a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a> <a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a> <a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a> <a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
</p> </p>
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.

View File

@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration # openGauss configuration
OPENGAUSS_HOST=127.0.0.1 OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600 OPENGAUSS_PORT=6600
@ -408,6 +417,9 @@ SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5 SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5 SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5 SSRF_DEFAULT_WRITE_TIME_OUT=5
SSRF_POOL_MAX_CONNECTIONS=100
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
BATCH_UPLOAD_LIMIT=10 BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database KEYWORD_DATA_SOURCE_TYPE=database
@ -418,10 +430,14 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION # CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox CODE_EXECUTION_API_KEY=dify-sandbox
CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807 CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808 CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000 CODE_MAX_STRING_LENGTH=400000
TEMPLATE_TRANSFORM_MAX_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=400000
CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_STRING_ARRAY_LENGTH=30
CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30
CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_MAX_NUMBER_ARRAY_LENGTH=1000
@ -461,7 +477,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5 WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800 MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration # GraphEngine Worker Pool Configuration

View File

@ -81,7 +81,6 @@ ignore = [
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
] ]
[lint.per-file-ignores] [lint.per-file-ignores]

View File

@ -80,10 +80,10 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
``` ```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
```bash ```bash
uv run celery -A app.celery beat uv run celery -A app.celery beat

View File

@ -1,20 +1,11 @@
import logging
import psycogreen.gevent as pscycogreen_gevent # type: ignore import psycogreen.gevent as pscycogreen_gevent # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore from grpc.experimental import gevent as grpc_gevent # type: ignore
_logger = logging.getLogger(__name__)
def _log(message: str):
_logger.debug(message)
# grpc gevent # grpc gevent
grpc_gevent.init_gevent() grpc_gevent.init_gevent()
_log("gRPC patched with gevent.") print("gRPC patched with gevent.", flush=True) # noqa: T201
pscycogreen_gevent.patch_psycopg() pscycogreen_gevent.patch_psycopg()
_log("psycopg2 patched with gevent.") print("psycopg2 patched with gevent.", flush=True) # noqa: T201
from app import app, celery from app import app, celery

View File

@ -10,6 +10,7 @@ from flask import current_app
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip(): if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red")) click.echo(click.style("Passwords do not match.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none() if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account: try:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) valid_password(new_password)
return except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try: # generate password salt
valid_password(new_password) salt = secrets.token_bytes(16)
except: base64_salt = base64.b64encode(salt).decode()
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
# generate password salt # encrypt password with salt
salt = secrets.token_bytes(16) password_hashed = hash_password(new_password, salt)
base64_salt = base64.b64encode(salt).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
# encrypt password with salt account.password_salt = base64_salt
password_hashed = hash_password(new_password, salt) AccountService.reset_login_error_rate_limit(email)
base64_password_hashed = base64.b64encode(password_hashed).decode() click.echo(click.style("Password reset successfully.", fg="green"))
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
@click.command("reset-email", help="Reset the account email.") @click.command("reset-email", help="Reset the account email.")
@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip(): if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red")) click.echo(click.style("New emails do not match.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none() if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account: try:
click.echo(click.style(f"Account not found for email: {email}", fg="red")) email_validate(new_email)
return except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try: account.email = new_email
email_validate(new_email) click.echo(click.style("Email updated successfully.", fg="green"))
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
@click.command( @click.command(
@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return
tenants = db.session.query(Tenant).all() tenant.encrypt_public_key = generate_key_pair(tenant.id)
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id) session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() click.echo(
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() click.style(
db.session.commit() f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
click.echo( )
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
) )
)
@click.command("vdb-migrate", help="Migrate vector db.") @click.command("vdb-migrate", help="Migrate vector db.")
@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
try: try:
# get apps info # get apps info
per_page = 50 per_page = 50
apps = ( with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
db.session.query(App) apps = (
.where(App.status == "normal") session.query(App)
.order_by(App.created_at.desc()) .where(App.status == "normal")
.limit(per_page) .order_by(App.created_at.desc())
.offset((page - 1) * per_page) .limit(per_page)
.all() .offset((page - 1) * per_page)
) .all()
)
if not apps: if not apps:
break break
except SQLAlchemyError: except SQLAlchemyError:
@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
) )
try: try:
click.echo(f"Creating app annotation index: {app.id}") click.echo(f"Creating app annotation index: {app.id}")
app_annotation_setting = ( with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() app_annotation_setting = (
) session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting: if not app_annotation_setting:
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}") click.echo(f"App annotation setting disabled: {app.id}")
continue continue
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first() .first()
) )
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}") click.echo(f"App annotation collection binding not found: {app.id}")
continue continue
annotations = db.session.scalars( annotations = session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all() ).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -1448,41 +1448,52 @@ def transform_datasource_credentials():
notion_credentials_tenant_mapping[tenant_id] = [] notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(notion_credential) notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if notion_plugin_id not in installed_plugins_ids: try:
if notion_plugin_unique_identifier: # check notion plugin is installed
# install notion plugin installed_plugins = installer_manager.list_plugins(tenant_id)
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
auth_count = 0 if notion_plugin_id not in installed_plugins_ids:
for notion_tenant_credential in notion_tenant_credentials: if notion_plugin_unique_identifier:
auth_count += 1 # install notion plugin
# get credential oauth params PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
access_token = notion_tenant_credential.access_token auth_count = 0
# notion info for notion_tenant_credential in notion_tenant_credentials:
notion_info = notion_tenant_credential.source_info auth_count += 1
workspace_id = notion_info.get("workspace_id") # get credential oauth params
workspace_name = notion_info.get("workspace_name") access_token = notion_tenant_credential.access_token
workspace_icon = notion_info.get("workspace_icon") # notion info
new_credentials = { notion_info = notion_tenant_credential.source_info
"integration_secret": encrypter.encrypt_token(tenant_id, access_token), workspace_id = notion_info.get("workspace_id")
"workspace_id": workspace_id, workspace_name = notion_info.get("workspace_name")
"workspace_name": workspace_name, workspace_icon = notion_info.get("workspace_icon")
"workspace_icon": workspace_icon, new_credentials = {
} "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
datasource_provider = DatasourceProvider( "workspace_id": workspace_id,
provider="notion_datasource", "workspace_name": workspace_name,
tenant_id=tenant_id, "workspace_icon": workspace_icon,
plugin_id=notion_plugin_id, }
auth_type=oauth_credential_type.value, datasource_provider = DatasourceProvider(
encrypted_credentials=new_credentials, provider="notion_datasource",
name=f"Auth {auth_count}", tenant_id=tenant_id,
avatar_url=workspace_icon or "default", plugin_id=notion_plugin_id,
is_default=False, auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
) )
db.session.add(datasource_provider) continue
deal_notion_count += 1
db.session.commit() db.session.commit()
# deal firecrawl credentials # deal firecrawl credentials
deal_firecrawl_count = 0 deal_firecrawl_count = 0
@ -1495,37 +1506,48 @@ def transform_datasource_credentials():
firecrawl_credentials_tenant_mapping[tenant_id] = [] firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if firecrawl_plugin_id not in installed_plugins_ids: try:
if firecrawl_plugin_unique_identifier: # check firecrawl plugin is installed
# install firecrawl plugin installed_plugins = installer_manager.list_plugins(tenant_id)
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0 auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials: for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1 auth_count += 1
# get credential api key # get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials) credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url") base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = { new_credentials = {
"firecrawl_api_key": api_key, "firecrawl_api_key": api_key,
"base_url": base_url, "base_url": base_url,
} }
datasource_provider = DatasourceProvider( datasource_provider = DatasourceProvider(
provider="firecrawl", provider="firecrawl",
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id, plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value, auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials, encrypted_credentials=new_credentials,
name=f"Auth {auth_count}", name=f"Auth {auth_count}",
avatar_url="default", avatar_url="default",
is_default=False, is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
except Exception as e:
click.echo(
click.style(
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
)
) )
db.session.add(datasource_provider) continue
deal_firecrawl_count += 1
db.session.commit() db.session.commit()
# deal jina credentials # deal jina credentials
deal_jina_count = 0 deal_jina_count = 0
@ -1538,36 +1560,45 @@ def transform_datasource_credentials():
jina_credentials_tenant_mapping[tenant_id] = [] jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(jina_credential) jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
installed_plugins = installer_manager.list_plugins(tenant_id) if not tenant:
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] continue
if jina_plugin_id not in installed_plugins_ids: try:
if jina_plugin_unique_identifier: # check jina plugin is installed
# install jina plugin installed_plugins = installer_manager.list_plugins(tenant_id)
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0 auth_count = 0
for jina_tenant_credential in jina_tenant_credentials: for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1 auth_count += 1
# get credential api key # get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials) credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key") api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = { new_credentials = {
"integration_secret": api_key, "integration_secret": api_key,
} }
datasource_provider = DatasourceProvider( datasource_provider = DatasourceProvider(
provider="jina", provider="jina",
tenant_id=tenant_id, tenant_id=tenant_id,
plugin_id=jina_plugin_id, plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value, auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials, encrypted_credentials=new_credentials,
name=f"Auth {auth_count}", name=f"Auth {auth_count}",
avatar_url="default", avatar_url="default",
is_default=False, is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
except Exception as e:
click.echo(
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
) )
db.session.add(datasource_provider) continue
deal_jina_count += 1
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
default=10.0, default=10.0,
) )
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
description="Maximum number of concurrent connections for the code execution HTTP client",
default=100,
)
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
default=20,
)
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
default=5.0,
)
CODE_MAX_NUMBER: PositiveInt = Field( CODE_MAX_NUMBER: PositiveInt = Field(
description="Maximum allowed numeric value in code execution", description="Maximum allowed numeric value in code execution",
default=9223372036854775807, default=9223372036854775807,
@ -135,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_MAX_STRING_LENGTH: PositiveInt = Field( CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution", description="Maximum allowed length for strings in code execution",
default=80000, default=400_000,
) )
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
@ -153,6 +168,11 @@ class CodeExecutionSandboxConfig(BaseSettings):
default=1000, default=1000,
) )
CODE_EXECUTION_SSL_VERIFY: bool = Field(
description="Enable or disable SSL verification for code execution requests",
default=True,
)
class PluginConfig(BaseSettings): class PluginConfig(BaseSettings):
""" """
@ -342,11 +362,11 @@ class HttpConfig(BaseSettings):
) )
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
) )
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
@ -404,6 +424,21 @@ class HttpConfig(BaseSettings):
default=5, default=5,
) )
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
description="Maximum number of concurrent connections for the SSRF HTTP client",
default=100,
)
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
default=20,
)
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
default=5.0,
)
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field( RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers" description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
" when the app is behind a single trusted reverse proxy.", " when the app is behind a single trusted reverse proxy.",
@ -542,16 +577,16 @@ class WorkflowConfig(BaseSettings):
default=5, default=5,
) )
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
description="Maximum allowed depth for nested parallel executions",
default=3,
)
MAX_VARIABLE_SIZE: PositiveInt = Field( MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024, default=200 * 1024,
) )
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
description="Maximum number of characters allowed in Template Transform node output",
default=400_000,
)
# GraphEngine Worker Pool Configuration # GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance", description="Minimum number of workers per GraphEngine instance",
@ -736,7 +771,7 @@ class MailConfig(BaseSettings):
MAIL_TEMPLATING_TIMEOUT: int = Field( MAIL_TEMPLATING_TIMEOUT: int = Field(
description=""" description="""
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates. Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
Only available in sandbox mode.""", Only available in sandbox mode.""",
default=3, default=3,
) )

View File

@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig from .vdb.chroma_config import ChromaConfig
@ -330,6 +331,7 @@ class MiddlewareConfig(
ClickzettaConfig, ClickzettaConfig,
HuaweiCloudConfig, HuaweiCloudConfig,
MilvusConfig, MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig, MyScaleConfig,
OpenSearchConfig, OpenSearchConfig,
OracleConfig, OracleConfig,

View File

@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)

View File

@ -40,8 +40,12 @@ class OceanBaseVectorConfig(BaseSettings):
OCEANBASE_FULLTEXT_PARSER: str | None = Field( OCEANBASE_FULLTEXT_PARSER: str | None = Field(
description=( description=(
"Fulltext parser to use for text indexing. Options: 'japanese_ftparser' (Japanese), " "Fulltext parser to use for text indexing. "
"'thai_ftparser' (Thai), 'ik' (Chinese). Default is 'ik'" "Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
), ),
default="ik", default="ik",
) )

View File

@ -1,23 +1,24 @@
from enum import Enum from enum import StrEnum
from typing import Literal from typing import Literal
from pydantic import Field, PositiveInt from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings): class OpenSearchConfig(BaseSettings):
""" """
Configuration settings for OpenSearch Configuration settings for OpenSearch
""" """
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field( OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None, default=None,

View File

@ -5,7 +5,7 @@ import logging
import os import os
import time import time
import requests import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,10 +30,10 @@ class NacosHttpClient:
params = {} params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.RequestException as e: except httpx.RequestError as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password} params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login" url = "http://" + self.server + "/nacos/v1/auth/login"
try: try:
resp = requests.request("POST", url, headers=None, params=params) resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status() resp.raise_for_status()
response_data = resp.json() response_data = resp.json()
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")

View File

@ -1,4 +1,5 @@
from configs import dify_config from configs import dify_config
from libs.collection_utils import convert_to_lower_and_upper_set
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]" UNKNOWN_VALUE = "[__UNKNOWN__]"
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3 DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
_doc_extensions: set[str]
_doc_extensions: list[str]
if dify_config.ETL_TYPE == "Unstructured": if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] _doc_extensions = {
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) "txt",
"markdown",
"md",
"mdx",
"pdf",
"html",
"htm",
"xlsx",
"xls",
"vtt",
"properties",
"doc",
"docx",
"csv",
"eml",
"msg",
"pptx",
"xml",
"epub",
}
if dify_config.UNSTRUCTURED_API_URL: if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.append("ppt") _doc_extensions.add("ppt")
else: else:
_doc_extensions = [ _doc_extensions = {
"txt", "txt",
"markdown", "markdown",
"md", "md",
@ -37,5 +53,5 @@ else:
"csv", "csv",
"vtt", "vtt",
"properties", "properties",
] }
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)

View File

@ -1,31 +1,10 @@
from importlib import import_module
from flask import Blueprint from flask import Blueprint
from flask_restx import Namespace from flask_restx import Namespace
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from .explore.message import (
MessageFeedbackApi,
MessageListApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from .explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api") bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi( api = ExternalApi(
@ -35,23 +14,23 @@ api = ExternalApi(
description="Console management APIs for app configuration, monitoring, and administration", description="Console management APIs for app configuration, monitoring, and administration",
) )
# Create namespace
console_ns = Namespace("console", description="Console management API operations", path="/") console_ns = Namespace("console", description="Console management API operations", path="/")
# File RESOURCE_MODULES = (
api.add_resource(FileApi, "/files/upload") "controllers.console.app.app_import",
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") "controllers.console.explore.audio",
api.add_resource(FileSupportTypeApi, "/files/support-type") "controllers.console.explore.completion",
"controllers.console.explore.conversation",
"controllers.console.explore.message",
"controllers.console.explore.workflow",
"controllers.console.files",
"controllers.console.remote_files",
)
# Remote files for module_name in RESOURCE_MODULES:
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") import_module(module_name)
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Ensure resource modules are imported so route decorators are evaluated.
# Import other controllers # Import other controllers
from . import ( from . import (
admin, admin,
@ -150,77 +129,6 @@ from .workspace import (
workspace, workspace,
) )
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# Explore Completion
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
# Explore Conversation
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
# Explore Message
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
# Explore Workflow
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
api.add_namespace(console_ns) api.add_namespace(console_ns)
__all__ = [ __all__ = [

View File

@ -19,6 +19,7 @@ from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required from libs.login import login_required
from libs.validators import validate_description_length
from models import Account, App from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/apps") @console_ns.route("/apps")
class AppListApi(Resource): class AppListApi(Resource):
@api.doc("list_apps") @api.doc("list_apps")
@ -138,7 +133,7 @@ class AppListApi(Resource):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
@ -219,7 +214,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
@ -297,7 +292,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json") parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json") parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
@ -309,7 +304,7 @@ class AppCopyApi(Resource):
account = cast(Account, current_user) account = cast(Account, current_user)
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=ImportMode.YAML_CONTENT.value, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.get("name"), name=args.get("name"),
description=args.get("description"), description=args.get("description"),

View File

@ -20,7 +20,10 @@ from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .. import console_ns
@console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -67,13 +70,14 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED.value: if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value: elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource): class AppImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -93,11 +97,12 @@ class AppImportConfirmApi(Resource):
session.commit() session.commit()
# Return appropriate status code based on result # Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource): class AppImportCheckDependenciesApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
import pytz # pip install pytz import pytz # pip install pytz
import sqlalchemy as sa
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = db.select(Conversation).where( query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
) )
@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
.subquery() .subquery()
) )
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args["keyword"]: if args["keyword"]:
keyword_filter = f"%{args['keyword']}%" keyword_filter = f"%{args['keyword']}%"
@ -308,7 +309,7 @@ class ChatConversationApi(Resource):
) )
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]: match args["sort_by"]:
case "created_at": case "created_at":

View File

@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models.account import Account from models.account import Account
from models.model import AppMode, AppModelConfig from models.model import AppMode, AppModelConfig
@ -90,7 +91,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3: if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue continue
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
@ -124,7 +125,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input # encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool # get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
@ -172,6 +173,8 @@ class ModelConfigResource(Resource):
db.session.flush() db.session.flush()
app_model.app_model_config_id = new_app_model_config.id app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@ -52,7 +52,7 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
) )
.select_from(Message) .select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
) )
if args["start"]: if args["start"]:
@ -190,7 +190,7 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -263,7 +263,7 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -345,7 +345,7 @@ FROM
WHERE WHERE
c.app_id = :app_id c.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -432,7 +432,7 @@ LEFT JOIN
WHERE WHERE
m.app_id = :app_id m.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -509,7 +509,7 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -584,7 +584,7 @@ FROM
WHERE WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc

View File

@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from configs import dify_config
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -26,6 +25,7 @@ from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App from models import App
@ -675,8 +675,12 @@ class PublishedWorkflowApi(Resource):
marked_comment=args.marked_comment or "", marked_comment=args.marked_comment or "",
) )
app_model.workflow_id = workflow.id # Update app_model within the same session to ensure atomicity
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at) workflow_created_at = TimestampField().format(workflow.created_at)
@ -797,24 +801,6 @@ class ConvertToWorkflowApi(Resource):
} }
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
@console_ns.route("/apps/<uuid:app_id>/workflows") @console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource): class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows") @api.doc("get_all_published_workflows")

View File

@ -47,7 +47,7 @@ WHERE
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -115,7 +115,7 @@ WHERE
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -183,7 +183,7 @@ WHERE
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
@ -269,7 +269,7 @@ GROUP BY
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)

View File

@ -103,7 +103,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"] account.interface_language = args["interface_language"]
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -2,7 +2,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@ -10,6 +10,7 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@console_ns.route("/api-key-auth/data-source")
class ApiKeyAuthDataSource(Resource): class ApiKeyAuthDataSource(Resource):
@setup_required @setup_required
@login_required @login_required
@ -33,6 +34,7 @@ class ApiKeyAuthDataSource(Resource):
return {"sources": []} return {"sources": []}
@console_ns.route("/api-key-auth/data-source/binding")
class ApiKeyAuthDataSourceBinding(Resource): class ApiKeyAuthDataSourceBinding(Resource):
@setup_required @setup_required
@login_required @login_required
@ -54,6 +56,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
class ApiKeyAuthDataSourceBindingDelete(Resource): class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required @setup_required
@login_required @login_required
@ -66,8 +69,3 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

View File

@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailCodeError, EmailCodeError,
@ -25,6 +25,7 @@ from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource): class EmailRegisterSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -52,6 +53,7 @@ class EmailRegisterSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-register/validity")
class EmailRegisterCheckApi(Resource): class EmailRegisterCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -92,6 +94,7 @@ class EmailRegisterCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/email-register")
class EmailRegisterResetApi(Resource): class EmailRegisterResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -148,8 +151,3 @@ class EmailRegisterResetApi(Resource):
raise AccountInFreezeError() raise AccountInFreezeError()
return account return account
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
api.add_resource(EmailRegisterResetApi, "/email-register")

View File

@ -221,8 +221,3 @@ class ForgotPasswordResetApi(Resource):
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -7,7 +7,7 @@ from flask_restx import Resource, reqparse
import services import services
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AuthenticationFailedError, AuthenticationFailedError,
EmailCodeError, EmailCodeError,
@ -34,6 +34,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService from services.feature_service import FeatureService
@console_ns.route("/login")
class LoginApi(Resource): class LoginApi(Resource):
"""Resource for user login.""" """Resource for user login."""
@ -91,6 +92,7 @@ class LoginApi(Resource):
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@console_ns.route("/logout")
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
@ -102,6 +104,7 @@ class LogoutApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/reset-password")
class ResetPasswordSendEmailApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@ -130,6 +133,7 @@ class ResetPasswordSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@ -162,6 +166,7 @@ class EmailCodeLoginSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@ -218,6 +223,7 @@ class EmailCodeLoginApi(Resource):
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@console_ns.route("/refresh-token")
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -229,11 +235,3 @@ class RefreshTokenApi(Resource):
return {"result": "success", "data": new_token_pair.model_dump()} return {"result": "success", "data": new_token_pair.model_dump()}
except Exception as e: except Exception as e:
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "data": str(e)}, 401
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
@ -101,8 +101,10 @@ class OAuthCallback(Resource):
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e: except httpx.RequestError as e:
error_text = e.response.text if e.response else str(e) error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
@ -128,11 +130,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -14,7 +14,7 @@ from models.account import Account
from models.model import OAuthProviderApp from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api from .. import console_ns
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@ -86,6 +86,7 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
return decorated return decorated
@console_ns.route("/oauth/provider")
class OAuthServerAppApi(Resource): class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@ -108,6 +109,7 @@ class OAuthServerAppApi(Resource):
) )
@console_ns.route("/oauth/provider/authorize")
class OAuthServerUserAuthorizeApi(Resource): class OAuthServerUserAuthorizeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -125,6 +127,7 @@ class OAuthServerUserAuthorizeApi(Resource):
) )
@console_ns.route("/oauth/provider/token")
class OAuthServerUserTokenApi(Resource): class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
) )
@console_ns.route("/oauth/provider/account")
class OAuthServerUserAccountApi(Resource): class OAuthServerUserAccountApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
"timezone": account.timezone, "timezone": account.timezone,
} }
) )
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@ -1,12 +1,13 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.model import Account from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@console_ns.route("/billing/subscription")
class Subscription(Resource): class Subscription(Resource):
@setup_required @setup_required
@login_required @login_required
@ -26,6 +27,7 @@ class Subscription(Resource):
) )
@console_ns.route("/billing/invoices")
class Invoices(Resource): class Invoices(Resource):
@setup_required @setup_required
@login_required @login_required
@ -36,7 +38,3 @@ class Invoices(Resource):
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, "/billing/subscription")
api.add_resource(Invoices, "/billing/invoices")

View File

@ -6,10 +6,11 @@ from libs.helper import extract_remote_ip
from libs.login import login_required from libs.login import login_required
from services.billing_service import BillingService from services.billing_service import BillingService
from .. import api from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required from ..wraps import account_initialization_required, only_edition_cloud, setup_required
@console_ns.route("/compliance/download")
class ComplianceApi(Resource): class ComplianceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -30,6 +31,3 @@ class ComplianceApi(Resource):
ip=ip_address, ip=ip_address,
device_info=device_info, device_info=device_info,
) )
api.add_resource(ComplianceApi, "/compliance/download")

View File

@ -9,13 +9,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@ -27,6 +27,10 @@ from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task from tasks.document_indexing_sync_task import document_indexing_sync_task
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
)
class DataSourceApi(Resource): class DataSourceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -109,6 +113,7 @@ class DataSourceApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/notion/pre-import/pages")
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -196,6 +201,10 @@ class DataSourceNotionListApi(Resource):
return {"notion_info": {**workspace_info, "pages": pages}}, 200 return {"notion_info": {**workspace_info, "pages": pages}}, 200
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource): class DataSourceNotionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -247,14 +256,16 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -269,6 +280,7 @@ class DataSourceNotionApi(Resource):
return response.model_dump(), 200 return response.model_dump(), 200
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDatasetSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -285,6 +297,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
class DataSourceNotionDocumentSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -301,16 +314,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
document_indexing_sync_task.delay(dataset_id_str, document_id_str) document_indexing_sync_task.delay(dataset_id_str, document_id_str)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
api.add_resource(
DataSourceNotionApi,
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
api.add_resource(
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
)

View File

@ -1,4 +1,5 @@
import flask_restx from typing import Any, cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@ -23,31 +24,27 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/datasets") @console_ns.route("/datasets")
class DatasetListApi(Resource): class DatasetListApi(Resource):
@api.doc("get_datasets") @api.doc("get_datasets")
@ -92,7 +89,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields) data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
for item in data: for item in data:
# convert embedding_model_provider to plugin standard format # convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@ -147,7 +144,7 @@ class DatasetListApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=_validate_description_length, type=validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@ -192,7 +189,7 @@ class DatasetListApi(Resource):
name=args["name"], name=args["name"],
description=args["description"], description=args["description"],
indexing_technique=args["indexing_technique"], indexing_technique=args["indexing_technique"],
account=current_user, account=cast(Account, current_user),
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
@ -224,7 +221,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider: if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider) provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -288,7 +285,7 @@ class DatasetApi(Resource):
help="type is required. Name must be between 1 to 40 characters.", help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name, type=_validate_name,
) )
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
parser.add_argument( parser.add_argument(
"indexing_technique", "indexing_technique",
type=str, type=str,
@ -369,7 +366,7 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
@ -503,7 +500,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details: if file_details:
for file_detail in file_details: for file_detail in file_details:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, datasource_type=DatasourceType.FILE,
upload_file=file_detail, upload_file=file_detail,
document_model=args["doc_form"], document_model=args["doc_form"],
) )
@ -515,14 +512,16 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id") credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]: for page in notion_info["pages"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": credential_id, {
"notion_workspace_id": workspace_id, "credential_id": credential_id,
"notion_obj_id": page["page_id"], "notion_workspace_id": workspace_id,
"notion_page_type": page["type"], "notion_obj_id": page["page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": page["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -530,15 +529,17 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"] website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]: for url in website_info_list["urls"]:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": website_info_list["provider"], {
"job_id": website_info_list["job_id"], "provider": website_info_list["provider"],
"url": url, "job_id": website_info_list["job_id"],
"tenant_id": current_user.current_tenant_id, "url": url,
"mode": "crawl", "tenant_id": current_user.current_tenant_id,
"only_main_content": website_info_list["only_main_content"], "mode": "crawl",
}, "only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"], document_model=args["doc_form"],
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -688,7 +689,7 @@ class DatasetApiKeyApi(Resource):
) )
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restx.abort( api.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code="max_keys_exceeded",
@ -733,7 +734,7 @@ class DatasetApiDeleteApi(Resource):
) )
if key is None: if key is None:
flask_restx.abort(404, message="API key not found") api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
@ -785,7 +786,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.VIKINGDB | VectorType.VIKINGDB
| VectorType.UPSTASH | VectorType.UPSTASH
): ):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case ( case (
VectorType.QDRANT VectorType.QDRANT
| VectorType.WEAVIATE | VectorType.WEAVIATE
@ -809,12 +810,13 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MATRIXONE | VectorType.MATRIXONE
| VectorType.CLICKZETTA | VectorType.CLICKZETTA
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH.value, RetrievalMethod.HYBRID_SEARCH,
] ]
} }
case _: case _:
@ -841,7 +843,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.VIKINGDB | VectorType.VIKINGDB
| VectorType.UPSTASH | VectorType.UPSTASH
): ):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case ( case (
VectorType.QDRANT VectorType.QDRANT
| VectorType.WEAVIATE | VectorType.WEAVIATE
@ -863,12 +865,13 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MATRIXONE | VectorType.MATRIXONE
| VectorType.CLICKZETTA | VectorType.CLICKZETTA
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH.value, RetrievalMethod.HYBRID_SEARCH,
] ]
} }
case _: case _:

View File

@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, cast from typing import Literal, cast
import sqlalchemy as sa
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@ -43,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields, dataset_and_document_fields,
@ -54,6 +55,7 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -211,13 +213,13 @@ class DatasetDocumentListApi(Resource):
if sort == "hit_count": if sort == "hit_count":
sub_query = ( sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id) .group_by(DocumentSegment.document_id)
.subquery() .subquery()
) )
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position), sort_logic(Document.position),
) )
elif sort == "created_at": elif sort == "created_at":
@ -303,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "doc_language", type=str, default="English", required=False, nullable=False, location="json"
) )
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -393,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -417,7 +419,9 @@ class DatasetInitApi(Resource):
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user tenant_id=current_user.current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -451,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise DocumentAlreadyFinishedError() raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
@ -471,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
) )
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
@ -513,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not documents: if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
extract_settings = [] extract_settings = []
for document in documents: for document in documents:
if document.indexing_status in {"completed", "error"}: if document.indexing_status in {"completed", "error"}:
@ -534,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.") raise NotFound("File not found.")
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -542,14 +546,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info: if not data_source_info:
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION,
notion_info={ notion_info=NotionInfo.model_validate(
"credential_id": data_source_info["credential_id"], {
"notion_workspace_id": data_source_info["notion_workspace_id"], "credential_id": data_source_info["credential_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_page_type": data_source_info["type"], "notion_obj_id": data_source_info["notion_page_id"],
"tenant_id": current_user.current_tenant_id, "notion_page_type": data_source_info["type"],
}, "tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -557,15 +563,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info: if not data_source_info:
continue continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE,
website_info={ website_info=WebsiteInfo.model_validate(
"provider": data_source_info["provider"], {
"job_id": data_source_info["job_id"], "provider": data_source_info["provider"],
"url": data_source_info["url"], "job_id": data_source_info["job_id"],
"tenant_id": current_user.current_tenant_id, "url": data_source_info["url"],
"mode": data_source_info["mode"], "tenant_id": current_user.current_tenant_id,
"only_main_content": data_source_info["only_main_content"], "mode": data_source_info["mode"],
}, "only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form, document_model=document.doc_form,
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
@ -752,7 +760,7 @@ class DocumentApi(DocumentResource):
} }
else: else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -1072,7 +1080,9 @@ class DocumentRenameApi(DocumentResource):
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset) if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -1113,6 +1123,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
class DocumentPipelineExecutionLogApi(DocumentResource): class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@ -1146,29 +1157,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data, "input_data": log.input_data,
"datasource_node_id": log.datasource_node_id, "datasource_node_id": log.datasource_node_id,
}, 200 }, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
api.add_resource(
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
)
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@ -7,7 +7,7 @@ from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import ( from controllers.console.datasets.error import (
ChildChunkDeleteIndexError, ChildChunkDeleteIndexError,
@ -37,6 +37,7 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -139,6 +140,7 @@ class DatasetDocumentSegmentListApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
class DatasetDocumentSegmentApi(Resource): class DatasetDocumentSegmentApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -193,6 +195,7 @@ class DatasetDocumentSegmentApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
class DatasetDocumentSegmentAddApi(Resource): class DatasetDocumentSegmentAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -244,6 +247,7 @@ class DatasetDocumentSegmentAddApi(Resource):
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
class DatasetDocumentSegmentUpdateApi(Resource): class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -305,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@ -345,6 +349,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
class DatasetDocumentSegmentBatchImportApi(Resource): class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -384,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
# send batch add segments task # send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting") redis_client.setnx(indexing_cache_key, "waiting")
batch_create_segment_to_index_task.delay( batch_create_segment_to_index_task.delay(
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id str(job_id),
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_user.id,
) )
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {"error": str(e)}, 500
@ -393,7 +406,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id): def get(self, job_id=None, dataset_id=None, document_id=None):
if job_id is None:
raise NotFound("The job does not exist.")
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = f"segment_batch_import_{job_id}" indexing_cache_key = f"segment_batch_import_{job_id}"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
@ -403,6 +418,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
class ChildChunkAddApi(Resource): class ChildChunkAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -457,7 +473,8 @@ class ChildChunkAddApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -546,13 +563,17 @@ class ChildChunkAddApi(Resource):
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200 return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@console_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
)
class ChildChunkUpdateApi(Resource): class ChildChunkUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -660,33 +681,8 @@ class ChildChunkUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
child_chunk = SegmentService.update_child_chunk( content = args["content"]
args.get("content"), child_chunk, segment, document, dataset child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
api.add_resource(
DatasetDocumentSegmentBatchImportApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)

View File

@ -1,3 +1,5 @@
from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal, reqparse
@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100: if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.") raise ValueError("Name must be between 1 to 100 characters.")
return name return name
@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=current_user, account=cast(Account, current_user),
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=args["metadata_filtering_conditions"],
) )

View File

@ -1,10 +1,11 @@
import logging import logging
from typing import cast
from flask_login import current_user from flask_login import current_user
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service import services
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -20,6 +21,7 @@ from core.errors.error import (
) )
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields from fields.hit_testing_fields import hit_testing_record_fields
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=current_user, account=cast(Account, current_user),
retrieval_model=args["retrieval_model"], retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
limit=10, limit=10,

View File

@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required from libs.login import login_required
@ -16,6 +16,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -27,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataArgs(**args) metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -50,6 +51,7 @@ class DatasetMetadataCreateApi(Resource):
return MetadataService.get_dataset_metadatas(dataset), 200 return MetadataService.get_dataset_metadatas(dataset), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataApi(Resource): class DatasetMetadataApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -60,6 +62,7 @@ class DatasetMetadataApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -68,7 +71,7 @@ class DatasetMetadataApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200 return metadata, 200
@setup_required @setup_required
@ -87,6 +90,7 @@ class DatasetMetadataApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/datasets/metadata/built-in")
class DatasetMetadataBuiltInFieldApi(Resource): class DatasetMetadataBuiltInFieldApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -97,6 +101,7 @@ class DatasetMetadataBuiltInFieldApi(Resource):
return {"fields": built_in_fields}, 200 return {"fields": built_in_fields}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
class DatasetMetadataBuiltInFieldActionApi(Resource): class DatasetMetadataBuiltInFieldActionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -116,6 +121,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditApi(Resource): class DocumentMetadataEditApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -131,15 +137,8 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataOperationData(**args) metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")

View File

@ -1,16 +1,16 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen from libs.helper import StrLen
from libs.login import login_required from libs.login import login_required
@ -19,6 +19,7 @@ from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@ -68,6 +69,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
return response return response
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
class DatasourceOAuthCallback(Resource): class DatasourceOAuthCallback(Resource):
@setup_required @setup_required
def get(self, provider_id: str): def get(self, provider_id: str):
@ -123,6 +125,7 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@ -165,6 +168,7 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200 return {"result": datasources}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -188,6 +192,7 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -213,6 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
return {"result": "success"}, 201 return {"result": "success"}, 201
@console_ns.route("/auth/plugin/datasource/list")
class DatasourceAuthListApi(Resource): class DatasourceAuthListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -225,6 +231,7 @@ class DatasourceAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@console_ns.route("/auth/plugin/datasource/default-list")
class DatasourceHardCodeAuthListApi(Resource): class DatasourceHardCodeAuthListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -237,6 +244,7 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@ -271,6 +279,7 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -291,6 +300,7 @@ class DatasourceAuthDefaultApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -311,52 +321,3 @@ class DatasourceUpdateProviderNameApi(Resource):
credential_id=args["credential_id"], credential_id=args["credential_id"],
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
DatasourceOAuthCallback,
"/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource/<path:provider_id>",
)
api.add_resource(
DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
api.add_resource(
DatasourceHardCodeAuthListApi,
"/auth/plugin/datasource/default-list",
)
api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)
api.add_resource(
DatasourceAuthDefaultApi,
"/auth/plugin/datasource/<path:provider_id>/default",
)
api.add_resource(
DatasourceUpdateProviderNameApi,
"/auth/plugin/datasource/<path:provider_id>/update-name",
)

View File

@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
) )
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -13,6 +13,7 @@ from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -49,9 +50,3 @@ class DataSourceContentPreviewApi(Resource):
credential_id=args.get("credential_id"), credential_id=args.get("credential_id"),
) )
return preview_content, 200 return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name): def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
return name return name
def _validate_description_length(description): def _validate_description_length(description: str) -> str:
if len(description) > 400: if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.") raise ValueError("Description cannot exceed 400 characters.")
return description return description
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource): class PipelineTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource):
return pipeline_templates, 200 return pipeline_templates, 200
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
class PipelineTemplateDetailApi(Resource): class PipelineTemplateDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -57,6 +59,7 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200 return pipeline_template, 200
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource): class CustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -73,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=str, type=_validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@ -85,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True, nullable=True,
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@ -112,6 +115,7 @@ class CustomizedPipelineTemplateApi(Resource):
return {"data": template.yaml_content}, 200 return {"data": template.yaml_content}, 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource): class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -129,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
) )
parser.add_argument( parser.add_argument(
"description", "description",
type=str, type=_validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@ -144,21 +148,3 @@ class PublishCustomizedPipelineTemplateApi(Resource):
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"} return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@ -1,10 +1,10 @@
from flask_login import current_user # type: ignore # type: ignore from flask_login import current_user
from flask_restx import Resource, marshal, reqparse # type: ignore from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -20,18 +20,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name): @console_ns.route("/rag/pipeline/dataset")
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -84,6 +73,7 @@ class CreateRagPipelineDatasetApi(Resource):
return import_info, 201 return import_info, 201
@console_ns.route("/rag/pipeline/empty-dataset")
class CreateEmptyRagPipelineDatasetApi(Resource): class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -108,7 +98,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
), ),
) )
return marshal(dataset, dataset_detail_fields), 201 return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@ -1,24 +1,22 @@
import logging import logging
from typing import Any, NoReturn from typing import NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
) )
from controllers.console.app.workflow_draft_variable import ( from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
) )
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db from extensions.ext_database import db
@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
@ -104,13 +76,14 @@ def _api_prerequisite(f):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.is_editor: if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource): class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@ -168,6 +141,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
return None return None
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource): class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@ -190,6 +164,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
return Response("", 204) return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
class RagPipelineVariableApi(Resource): class RagPipelineVariableApi(Resource):
_PATCH_NAME_FIELD = "name" _PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value" _PATCH_VALUE_FIELD = "value"
@ -284,6 +259,7 @@ class RagPipelineVariableApi(Resource):
return Response("", 204) return Response("", 204)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource): class RagPipelineVariableResetApi(Resource):
@_api_prerequisite @_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str): def put(self, pipeline: Pipeline, variable_id: str):
@ -325,6 +301,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
return draft_vars return draft_vars
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource): class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@ -332,6 +309,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
class RagPipelineEnvironmentVariableCollectionApi(Resource): class RagPipelineEnvironmentVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@ -364,26 +342,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
) )
return {"items": env_vars_list} return {"items": env_vars_list}
api.add_resource(
RagPipelineVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
RagPipelineNodeVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
RagPipelineEnvironmentVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -20,6 +20,7 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource): class RagPipelineImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -59,13 +60,14 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED.value: if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value: elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202 return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
class RagPipelineImportConfirmApi(Resource): class RagPipelineImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -85,11 +87,12 @@ class RagPipelineImportConfirmApi(Resource):
session.commit() session.commit()
# Return appropriate status code based on result # Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value: if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400 return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
class RagPipelineImportCheckDependenciesApi(Resource): class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -107,6 +110,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
return result.model_dump(mode="json"), 200 return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
class RagPipelineExportApi(Resource): class RagPipelineExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -128,22 +132,3 @@ class RagPipelineExportApi(Resource):
) )
return {"data": result}, 200 return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from configs import dify_config from controllers.console import console_ns
from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
DraftWorkflowNotExist, DraftWorkflowNotExist,
@ -51,6 +50,7 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource): class DraftRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -148,6 +148,7 @@ class DraftRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -182,6 +183,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -216,6 +218,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -250,6 +253,7 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description) raise InvokeRateLimitHttpError(ex.description)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -370,6 +374,7 @@ class PublishedRagPipelineRunApi(Resource):
# #
# return result # return result
# #
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -412,6 +417,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
) )
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -454,6 +460,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
) )
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -487,6 +494,7 @@ class RagPipelineDraftNodeRunApi(Resource):
return workflow_node_execution return workflow_node_execution
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
class RagPipelineTaskStopApi(Resource): class RagPipelineTaskStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -505,6 +513,7 @@ class RagPipelineTaskStopApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
class PublishedRagPipelineApi(Resource): class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -560,6 +569,7 @@ class PublishedRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
class DefaultRagPipelineBlockConfigsApi(Resource): class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -578,6 +588,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -609,18 +620,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
class RagPipelineConfigApi(Resource): @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
"""Resource for rag pipeline configuration."""
@setup_required
@login_required
@account_initialization_required
def get(self, pipeline_id):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -669,6 +669,7 @@ class PublishedAllRagPipelineApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -726,6 +727,7 @@ class RagPipelineByIdApi(Resource):
return workflow return workflow
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -751,6 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -776,6 +779,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -801,6 +805,7 @@ class DraftRagPipelineFirstStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -827,6 +832,7 @@ class DraftRagPipelineSecondStepApi(Resource):
} }
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -848,6 +854,7 @@ class RagPipelineWorkflowRunListApi(Resource):
return result return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
class RagPipelineWorkflowRunDetailApi(Resource): class RagPipelineWorkflowRunDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -866,6 +873,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
return workflow_run return workflow_run
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
class RagPipelineWorkflowRunNodeExecutionListApi(Resource): class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -889,6 +897,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
return {"data": node_executions} return {"data": node_executions}
@console_ns.route("/rag/pipelines/datasource-plugins")
class DatasourceListApi(Resource): class DatasourceListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -904,6 +913,7 @@ class DatasourceListApi(Resource):
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id)) return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
class RagPipelineWorkflowLastRunApi(Resource): class RagPipelineWorkflowLastRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -925,6 +935,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
return node_exec return node_exec
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
class RagPipelineTransformApi(Resource): class RagPipelineTransformApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -942,6 +953,7 @@ class RagPipelineTransformApi(Resource):
return result return result
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -971,6 +983,7 @@ class RagPipelineDatasourceVariableApi(Resource):
return workflow_node_execution return workflow_node_execution
@console_ns.route("/rag/pipelines/recommended-plugins")
class RagPipelineRecommendedPluginApi(Resource): class RagPipelineRecommendedPluginApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -979,118 +992,3 @@ class RagPipelineRecommendedPluginApi(Resource):
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins() recommended_plugins = rag_pipeline_service.get_recommended_plugins()
return recommended_plugins return recommended_plugins
api.add_resource(
DraftRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
)
api.add_resource(
RagPipelineConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
)
api.add_resource(
DraftRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
)
api.add_resource(
PublishedRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
)
api.add_resource(
RagPipelineTaskStopApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelinePublishedDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftDatasourceNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunIterationNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelineDraftRunLoopNodeApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
)
api.add_resource(
PublishedAllRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows",
)
api.add_resource(
DefaultRagPipelineBlockConfigsApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultRagPipelineBlockConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
RagPipelineByIdApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
)
api.add_resource(
RagPipelineWorkflowRunListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
)
api.add_resource(
RagPipelineWorkflowRunDetailApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
)
api.add_resource(
RagPipelineWorkflowRunNodeExecutionListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
)
api.add_resource(
DatasourceListApi,
"/rag/pipelines/datasource-plugins",
)
api.add_resource(
PublishedRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
)
api.add_resource(
PublishedRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
)
api.add_resource(
DraftRagPipelineSecondStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
)
api.add_resource(
DraftRagPipelineFirstStepApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
)
api.add_resource(
RagPipelineWorkflowLastRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(
RagPipelineTransformApi,
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
)
api.add_resource(
RagPipelineDatasourceVariableApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
)
api.add_resource(
RagPipelineRecommendedPluginApi,
"/rag/pipelines/recommended-plugins",
)

View File

@ -26,9 +26,15 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError, UnsupportedAudioTypeServiceError,
) )
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio",
)
class ChatAudioApi(InstalledAppResource): class ChatAudioApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@ -65,6 +71,10 @@ class ChatAudioApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/text-to-audio",
endpoint="installed_app_text",
)
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
from flask_restx import reqparse from flask_restx import reqparse

View File

@ -33,10 +33,16 @@ from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# define completion api for user # define completion api for user
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@ -87,6 +93,10 @@ class CompletionApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
class CompletionStopApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
@ -100,6 +110,10 @@ class CompletionStopApi(InstalledAppResource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/chat-messages",
endpoint="installed_app_chat_completion",
)
class ChatApi(InstalledAppResource): class ChatApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@ -153,6 +167,10 @@ class ChatApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
class ChatStopApi(InstalledAppResource): class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app

View File

@ -16,7 +16,13 @@ from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
from .. import console_ns
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
@ -52,6 +58,10 @@ class ConversationListApi(InstalledAppResource):
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
class ConversationApi(InstalledAppResource): class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id): def delete(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
@ -70,6 +80,10 @@ class ConversationApi(InstalledAppResource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
@ -95,6 +109,10 @@ class ConversationRenameApi(InstalledAppResource):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
class ConversationPinApi(InstalledAppResource): class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
@ -114,6 +132,10 @@ class ConversationPinApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
class ConversationUnPinApi(InstalledAppResource): class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_, select from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db from extensions.ext_database import db
@ -22,6 +22,7 @@ from services.feature_service import FeatureService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -154,6 +155,7 @@ class InstalledAppsListApi(Resource):
return {"message": "App installed successfully"} return {"message": "App installed successfully"}
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
class InstalledAppApi(InstalledAppResource): class InstalledAppApi(InstalledAppResource):
""" """
update and delete an installed app update and delete an installed app
@ -185,7 +187,3 @@ class InstalledAppApi(InstalledAppResource):
db.session.commit() db.session.commit()
return {"result": "success", "message": "App info updated successfully"} return {"result": "success", "message": "App info updated successfully"}
api.add_resource(InstalledAppsListApi, "/installed-apps")
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")

View File

@ -36,9 +36,15 @@ from services.errors.message import (
) )
from services.message_service import MessageService from services.message_service import MessageService
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
@ -66,6 +72,10 @@ class MessageListApi(InstalledAppResource):
raise NotFound("First Message Not Exists.") raise NotFound("First Message Not Exists.")
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@ -93,6 +103,10 @@ class MessageFeedbackApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@ -139,6 +153,10 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
class MessageSuggestedQuestionApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app

View File

@ -1,7 +1,7 @@
from flask_restx import marshal_with from flask_restx import marshal_with
from controllers.common import fields from controllers.common import fields
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
class AppParameterApi(InstalledAppResource): class AppParameterApi(InstalledAppResource):
"""Resource for app variables.""" """Resource for app variables."""
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
if not app_model: if not app_model:
raise ValueError("App not found") raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)
api.add_resource(
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
)
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -35,6 +35,7 @@ recommended_app_list_fields = {
} }
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -56,13 +57,10 @@ class RecommendedAppListApi(Resource):
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
@console_ns.route("/explore/apps/<uuid:app_id>")
class RecommendedAppApi(Resource): class RecommendedAppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
app_id = str(app_id) app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id) return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps")
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")

View File

@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
@ -25,6 +25,7 @@ message_fields = {
} }
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource): class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = { saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer, "limit": fields.Integer,
@ -66,6 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
)
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id): def delete(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@ -80,15 +84,3 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
api.add_resource(
SavedMessageListApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)

View File

@ -27,9 +27,12 @@ from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
def post(self, installed_app: InstalledApp): def post(self, installed_app: InstalledApp):
""" """
@ -70,6 +73,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
class InstalledAppWorkflowTaskStopApi(InstalledAppResource): class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
def post(self, installed_app: InstalledApp, task_id: str): def post(self, installed_app: InstalledApp, task_id: str):
""" """

View File

@ -26,9 +26,12 @@ from libs.login import login_required
from models import Account from models import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns
PREVIEW_WORDS_LIMIT = 3000 PREVIEW_WORDS_LIMIT = 3000
@console_ns.route("/files/upload")
class FileApi(Resource): class FileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -88,6 +91,7 @@ class FileApi(Resource):
return upload_file, 201 return upload_file, 201
@console_ns.route("/files/<uuid:file_id>/preview")
class FilePreviewApi(Resource): class FilePreviewApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -98,6 +102,7 @@ class FilePreviewApi(Resource):
return {"content": text} return {"content": text}
@console_ns.route("/files/support-type")
class FileSupportTypeApi(Resource): class FileSupportTypeApi(Resource):
@setup_required @setup_required
@login_required @login_required

View File

@ -19,7 +19,10 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie
from models.account import Account from models.account import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource): class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields) @marshal_with(remote_file_info_fields)
def get(self, url): def get(self, url):
@ -35,6 +38,7 @@ class RemoteFileInfoApi(Resource):
} }
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):

View File

@ -2,7 +2,6 @@ import logging
from flask_restx import Resource from flask_restx import Resource
from controllers.console import api
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
setup_required, setup_required,
@ -10,9 +9,12 @@ from controllers.console.wraps import (
from core.schemas.schema_manager import SchemaManager from core.schemas.schema_manager import SchemaManager
from libs.login import login_required from libs.login import login_required
from . import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/spec/schema-definitions")
class SpecSchemaDefinitionsApi(Resource): class SpecSchemaDefinitionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -30,6 +32,3 @@ class SpecSchemaDefinitionsApi(Resource):
logger.exception("Failed to get schema definitions from local registry") logger.exception("Failed to get schema definitions from local registry")
# Return empty array as fallback # Return empty array as fallback
return [], 200 return [], 200
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

View File

@ -3,7 +3,7 @@ from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import login_required from libs.login import login_required
@ -17,6 +17,7 @@ def _validate_name(name):
return name return name
@console_ns.route("/tags")
class TagListApi(Resource): class TagListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -52,6 +53,7 @@ class TagListApi(Resource):
return response, 200 return response, 200
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource): class TagUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -89,6 +91,7 @@ class TagUpdateDeleteApi(Resource):
return 204 return 204
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource): class TagBindingCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -114,6 +117,7 @@ class TagBindingCreateApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -133,9 +137,3 @@ class TagBindingDeleteApi(Resource):
TagService.delete_tag_binding(args) TagService.delete_tag_binding(args)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(TagListApi, "/tags")
api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import requests import httpx
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from packaging import version from packaging import version
@ -57,7 +57,11 @@ class VersionApi(Resource):
return result return result
try: try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args["current_version"]

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
@ -45,6 +45,7 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@console_ns.route("/account/init")
class AccountInitApi(Resource): class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -97,6 +98,7 @@ class AccountInitApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/profile")
class AccountProfileApi(Resource): class AccountProfileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -109,6 +111,7 @@ class AccountProfileApi(Resource):
return current_user return current_user
@console_ns.route("/account/name")
class AccountNameApi(Resource): class AccountNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -130,6 +133,7 @@ class AccountNameApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource): class AccountAvatarApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -147,6 +151,7 @@ class AccountAvatarApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource): class AccountInterfaceLanguageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -164,6 +169,7 @@ class AccountInterfaceLanguageApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource): class AccountInterfaceThemeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -181,6 +187,7 @@ class AccountInterfaceThemeApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource): class AccountTimezoneApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -202,6 +209,7 @@ class AccountTimezoneApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/password")
class AccountPasswordApi(Resource): class AccountPasswordApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -227,6 +235,7 @@ class AccountPasswordApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource): class AccountIntegrateApi(Resource):
integrate_fields = { integrate_fields = {
"provider": fields.String, "provider": fields.String,
@ -283,6 +292,7 @@ class AccountIntegrateApi(Resource):
return {"data": integrate_data} return {"data": integrate_data}
@console_ns.route("/account/delete/verify")
class AccountDeleteVerifyApi(Resource): class AccountDeleteVerifyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -298,6 +308,7 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource): class AccountDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -320,6 +331,7 @@ class AccountDeleteApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
@ -333,6 +345,7 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource): class EducationVerifyApi(Resource):
verify_fields = { verify_fields = {
"token": fields.String, "token": fields.String,
@ -352,6 +365,7 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@console_ns.route("/account/education")
class EducationApi(Resource): class EducationApi(Resource):
status_fields = { status_fields = {
"result": fields.Boolean, "result": fields.Boolean,
@ -396,6 +410,7 @@ class EducationApi(Resource):
return res return res
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource): class EducationAutoCompleteApi(Resource):
data_fields = { data_fields = {
"data": fields.List(fields.String), "data": fields.List(fields.String),
@ -419,6 +434,7 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource): class ChangeEmailSendEmailApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@ -467,6 +483,7 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource): class ChangeEmailCheckApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@ -508,6 +525,7 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource): class ChangeEmailResetApi(Resource):
@enable_change_email @enable_change_email
@setup_required @setup_required
@ -547,6 +565,7 @@ class ChangeEmailResetApi(Resource):
return updated_account return updated_account
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@setup_required @setup_required
def post(self): def post(self):
@ -558,28 +577,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(args["email"]): if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
return {"result": "success"} return {"result": "success"}
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# Change email
api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -10,6 +10,9 @@ from models.account import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
)
class LoadBalancingCredentialsValidateApi(Resource): class LoadBalancingCredentialsValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -61,6 +64,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
return response return response
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
)
class LoadBalancingConfigCredentialsValidateApi(Resource): class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -111,15 +117,3 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
response["error"] = error response["error"] = error
return response return response
# Load Balancing Config
api.add_resource(
LoadBalancingCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
)
api.add_resource(
LoadBalancingConfigCredentialsValidateApi,
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
)

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
CannotTransferOwnerToSelfError, CannotTransferOwnerToSelfError,
EmailCodeError, EmailCodeError,
@ -33,6 +33,7 @@ from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService from services.feature_service import FeatureService
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource): class MemberListApi(Resource):
"""List all members of current tenant.""" """List all members of current tenant."""
@ -49,6 +50,7 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource): class MemberInviteEmailApi(Resource):
"""Invite a new member by email.""" """Invite a new member by email."""
@ -111,6 +113,7 @@ class MemberInviteEmailApi(Resource):
}, 201 }, 201
@console_ns.route("/workspaces/current/members/<uuid:member_id>")
class MemberCancelInviteApi(Resource): class MemberCancelInviteApi(Resource):
"""Cancel an invitation by member id.""" """Cancel an invitation by member id."""
@ -143,6 +146,7 @@ class MemberCancelInviteApi(Resource):
}, 200 }, 200
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
"""Update member role.""" """Update member role."""
@ -177,6 +181,7 @@ class MemberUpdateRoleApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/dataset-operators")
class DatasetOperatorMemberListApi(Resource): class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant.""" """List all members of current tenant."""
@ -193,6 +198,7 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource): class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email.""" """Send owner transfer email."""
@ -233,6 +239,7 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource): class OwnerTransferCheckApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -278,6 +285,7 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource): class OwnerTransfer(Resource):
@setup_required @setup_required
@login_required @login_required
@ -339,14 +347,3 @@ class OwnerTransfer(Resource):
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
# owner transfer
api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")

View File

@ -5,7 +5,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource): class ModelProviderListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -45,6 +46,7 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list}) return jsonable_encoder({"data": provider_list})
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource): class ModelProviderCredentialApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -151,6 +153,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource): class ModelProviderCredentialSwitchApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -175,6 +178,7 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -211,6 +215,7 @@ class ModelProviderValidateApi(Resource):
return response return response
@console_ns.route("/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>")
class ModelProviderIconApi(Resource): class ModelProviderIconApi(Resource):
""" """
Get model provider icon Get model provider icon
@ -229,6 +234,7 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype) return send_file(io.BytesIO(icon), mimetype=mimetype)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource): class PreferredProviderTypeUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -262,6 +268,7 @@ class PreferredProviderTypeUpdateApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/checkout-url")
class ModelProviderPaymentCheckoutUrlApi(Resource): class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -281,21 +288,3 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
prefilled_email=current_user.email, prefilled_email=current_user.email,
) )
return data return data
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
)
api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
api.add_resource(
ModelProviderIconApi,
"/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
)

View File

@ -4,7 +4,7 @@ from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -17,6 +17,7 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource): class DefaultModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -85,6 +86,7 @@ class DefaultModelApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource): class ModelProviderModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -187,6 +189,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource): class ModelProviderModelCredentialApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -364,6 +367,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204 return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource): class ModelProviderModelCredentialSwitchApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -395,6 +399,9 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource): class ModelProviderModelEnableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -422,6 +429,9 @@ class ModelProviderModelEnableApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource): class ModelProviderModelDisableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -449,6 +459,7 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource): class ModelProviderModelValidateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -494,6 +505,7 @@ class ModelProviderModelValidateApi(Resource):
return response return response
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource): class ModelProviderModelParameterRuleApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -513,6 +525,7 @@ class ModelProviderModelParameterRuleApi(Resource):
return jsonable_encoder({"data": parameter_rules}) return jsonable_encoder({"data": parameter_rules})
@console_ns.route("/workspaces/current/models/model-types/<string:model_type>")
class ModelProviderAvailableModelApi(Resource): class ModelProviderAvailableModelApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -524,32 +537,3 @@ class ModelProviderAvailableModelApi(Resource):
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models}) return jsonable_encoder({"data": models})
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
api.add_resource(
ModelProviderModelEnableApi,
"/workspaces/current/model-providers/<path:provider>/models/enable",
endpoint="model-provider-model-enable",
)
api.add_resource(
ModelProviderModelDisableApi,
"/workspaces/current/model-providers/<path:provider>/models/disable",
endpoint="model-provider-model-disable",
)
api.add_resource(
ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
)
api.add_resource(
ModelProviderModelCredentialSwitchApi,
"/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
)
api.add_resource(
ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
)
api.add_resource(
ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -19,6 +19,7 @@ from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource): class PluginDebuggingKeyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -37,6 +38,7 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource): class PluginListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -55,6 +57,7 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource): class PluginListLatestVersionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -72,6 +75,7 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions}) return jsonable_encoder({"versions": versions})
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -91,6 +95,7 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins}) return jsonable_encoder({"plugins": plugins})
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource): class PluginIconApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
@ -108,6 +113,7 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource): class PluginUploadFromPkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -131,6 +137,7 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource): class PluginUploadFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -153,6 +160,7 @@ class PluginUploadFromGithubApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/upload/bundle")
class PluginUploadFromBundleApi(Resource): class PluginUploadFromBundleApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -176,6 +184,7 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource): class PluginInstallFromPkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -201,6 +210,7 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource): class PluginInstallFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -230,6 +240,7 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource): class PluginInstallFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -255,6 +266,7 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource): class PluginFetchMarketplacePkgApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -280,6 +292,7 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -304,6 +317,7 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource): class PluginFetchInstallTasksApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -325,6 +339,7 @@ class PluginFetchInstallTasksApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>")
class PluginFetchInstallTaskApi(Resource): class PluginFetchInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -339,6 +354,7 @@ class PluginFetchInstallTaskApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
class PluginDeleteInstallTaskApi(Resource): class PluginDeleteInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -353,6 +369,7 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
class PluginDeleteAllInstallTaskItemsApi(Resource): class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -367,6 +384,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
class PluginDeleteInstallTaskItemApi(Resource): class PluginDeleteInstallTaskItemApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -381,6 +399,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource): class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -404,6 +423,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource): class PluginUpgradeFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -435,6 +455,7 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource): class PluginUninstallApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -453,6 +474,7 @@ class PluginUninstallApi(Resource):
raise ValueError(e) raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource): class PluginChangePermissionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -475,6 +497,7 @@ class PluginChangePermissionApi(Resource):
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@console_ns.route("/workspaces/current/plugin/permission/fetch")
class PluginFetchPermissionApi(Resource): class PluginFetchPermissionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -499,6 +522,7 @@ class PluginFetchPermissionApi(Resource):
) )
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource): class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -535,6 +559,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options}) return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource): class PluginChangePreferencesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -590,6 +615,7 @@ class PluginChangePreferencesApi(Resource):
return jsonable_encoder({"success": True}) return jsonable_encoder({"success": True})
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
class PluginFetchPreferencesApi(Resource): class PluginFetchPreferencesApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -628,6 +654,7 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource): class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -641,35 +668,3 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
args = req.parse_args() args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

View File

@ -10,7 +10,7 @@ from flask_restx import (
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
@ -47,6 +47,7 @@ def is_valid_url(url: str) -> bool:
return False return False
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource): class ToolProviderListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -71,6 +72,7 @@ class ToolProviderListApi(Resource):
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
class ToolBuiltinProviderListToolsApi(Resource): class ToolBuiltinProviderListToolsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -88,6 +90,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/info")
class ToolBuiltinProviderInfoApi(Resource): class ToolBuiltinProviderInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -100,6 +103,7 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -121,6 +125,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource): class ToolBuiltinProviderAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -150,6 +155,7 @@ class ToolBuiltinProviderAddApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource): class ToolBuiltinProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -181,6 +187,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
return result return result
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credentials")
class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -196,6 +203,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource): class ToolBuiltinProviderIconApi(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
@ -204,6 +212,7 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -243,6 +252,7 @@ class ToolApiProviderAddApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource): class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -266,6 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource): class ToolApiProviderListToolsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -291,6 +302,7 @@ class ToolApiProviderListToolsApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource): class ToolApiProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -332,6 +344,7 @@ class ToolApiProviderUpdateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource): class ToolApiProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -358,6 +371,7 @@ class ToolApiProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource): class ToolApiProviderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -381,6 +395,7 @@ class ToolApiProviderGetApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>")
class ToolBuiltinProviderCredentialsSchemaApi(Resource): class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -396,6 +411,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -412,6 +428,7 @@ class ToolApiProviderSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource): class ToolApiProviderPreviousTestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -439,6 +456,7 @@ class ToolApiProviderPreviousTestApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource): class ToolWorkflowProviderCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -478,6 +496,7 @@ class ToolWorkflowProviderCreateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource): class ToolWorkflowProviderUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -520,6 +539,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource): class ToolWorkflowProviderDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -545,6 +565,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource): class ToolWorkflowProviderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -579,6 +600,7 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool) return jsonable_encoder(tool)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource): class ToolWorkflowProviderListToolApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -603,6 +625,7 @@ class ToolWorkflowProviderListToolApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/builtin")
class ToolBuiltinListApi(Resource): class ToolBuiltinListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -624,6 +647,7 @@ class ToolBuiltinListApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/api")
class ToolApiListApi(Resource): class ToolApiListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -642,6 +666,7 @@ class ToolApiListApi(Resource):
) )
@console_ns.route("/workspaces/current/tools/workflow")
class ToolWorkflowListApi(Resource): class ToolWorkflowListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -663,6 +688,7 @@ class ToolWorkflowListApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-labels")
class ToolLabelsApi(Resource): class ToolLabelsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -672,6 +698,7 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels()) return jsonable_encoder(ToolLabelsService.list_tool_labels())
@console_ns.route("/oauth/plugin/<path:provider>/tool/authorization-url")
class ToolPluginOAuthApi(Resource): class ToolPluginOAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -716,6 +743,7 @@ class ToolPluginOAuthApi(Resource):
return response return response
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource): class ToolOAuthCallback(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
@ -766,6 +794,7 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource): class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -779,6 +808,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource): class ToolOAuthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@ -822,6 +852,7 @@ class ToolOAuthCustomClient(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema")
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -834,6 +865,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/credential/info")
class ToolBuiltinProviderGetCredentialInfoApi(Resource): class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -849,6 +881,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
) )
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -933,6 +966,7 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"} return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource): class ToolMCPAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -978,6 +1012,7 @@ class ToolMCPAuthApi(Resource):
raise ValueError(f"Failed to connect to MCP server: {e}") from e raise ValueError(f"Failed to connect to MCP server: {e}") from e
@console_ns.route("/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
class ToolMCPDetailApi(Resource): class ToolMCPDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -988,6 +1023,7 @@ class ToolMCPDetailApi(Resource):
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
class ToolMCPListAllApi(Resource): class ToolMCPListAllApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -1001,6 +1037,7 @@ class ToolMCPListAllApi(Resource):
return [tool.to_dict() for tool in tools] return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
class ToolMCPUpdateApi(Resource): class ToolMCPUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -1014,6 +1051,7 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools) return jsonable_encoder(tools)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -1024,67 +1062,3 @@ class ToolMCPCallbackApi(Resource):
authorization_code = args["code"] authorization_code = args["code"]
handle_callback(state_key, authorization_code) handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@ -14,7 +14,7 @@ from controllers.common.errors import (
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import api from controllers.console import console_ns
from controllers.console.admin import admin_required from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -65,6 +65,7 @@ tenants_fields = {
workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField}
@console_ns.route("/workspaces")
class TenantListApi(Resource): class TenantListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -93,6 +94,7 @@ class TenantListApi(Resource):
return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
@console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource): class WorkspaceListApi(Resource):
@setup_required @setup_required
@admin_required @admin_required
@ -118,6 +120,8 @@ class WorkspaceListApi(Resource):
}, 200 }, 200
@console_ns.route("/workspaces/current", endpoint="workspaces_current")
@console_ns.route("/info", endpoint="info") # Deprecated
class TenantApi(Resource): class TenantApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -143,11 +147,10 @@ class TenantApi(Resource):
else: else:
raise Unauthorized("workspace is archived") raise Unauthorized("workspace is archived")
if not tenant:
raise ValueError("No tenant available")
return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource): class SwitchWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -172,6 +175,7 @@ class SwitchWorkspaceApi(Resource):
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource): class CustomConfigWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -202,6 +206,7 @@ class CustomConfigWorkspaceApi(Resource):
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/custom-config/webapp-logo/upload")
class WebappLogoWorkspaceApi(Resource): class WebappLogoWorkspaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -242,6 +247,7 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201 return {"id": upload_file.id}, 201
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource): class WorkspaceInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -261,13 +267,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit() db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

View File

@ -24,20 +24,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
NOTE: user_id is not trusted, it could be maliciously set to any value. NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id. As a result, it could only be considered as an end user id.
""" """
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
if not user_id: user_model = None
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_model = ( if is_anonymous:
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = ( user_model = (
session.query(EndUser) session.query(EndUser)
.where( .where(
@ -46,11 +40,21 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
) )
.first() .first()
) )
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model: if not user_model:
user_model = EndUser( user_model = EndUser(
tenant_id=tenant_id, tenant_id=tenant_id,
type="service_api", type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, is_anonymous=is_anonymous,
session_id=user_id, session_id=user_id,
) )
session.add(user_model) session.add(user_model)
@ -81,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
raise ValueError("tenant_id is required") raise ValueError("tenant_id is required")
if not user_id: if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try: try:
tenant_model = ( tenant_model = (
@ -124,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json") raise ValueError("invalid json")
try: try:
payload = payload_type(**data) payload = payload_type.model_validate(data)
except Exception as e: except Exception as e:
raise ValueError(f"invalid payload: {str(e)}") raise ValueError(f"invalid payload: {str(e)}")

View File

@ -1,10 +1,10 @@
from typing import Literal from typing import Any, Literal, cast
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@ -17,6 +17,7 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
@ -31,12 +32,6 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
# Define parsers for dataset operations # Define parsers for dataset operations
dataset_create_parser = reqparse.RequestParser() dataset_create_parser = reqparse.RequestParser()
dataset_create_parser.add_argument( dataset_create_parser.add_argument(
@ -48,7 +43,7 @@ dataset_create_parser.add_argument(
) )
dataset_create_parser.add_argument( dataset_create_parser.add_argument(
"description", "description",
type=_validate_description_length, type=validate_description_length,
nullable=True, nullable=True,
required=False, required=False,
default="", default="",
@ -101,7 +96,7 @@ dataset_update_parser.add_argument(
type=_validate_name, type=_validate_name,
) )
dataset_update_parser.add_argument( dataset_update_parser.add_argument(
"description", location="json", store_missing=False, type=_validate_description_length "description", location="json", store_missing=False, type=validate_description_length
) )
dataset_update_parser.add_argument( dataset_update_parser.add_argument(
"indexing_technique", "indexing_technique",
@ -254,19 +249,21 @@ class DatasetListApi(DatasetApiResource):
"""Resource for creating datasets.""" """Resource for creating datasets."""
args = dataset_create_parser.parse_args() args = dataset_create_parser.parse_args()
if args.get("embedding_model_provider"): embedding_model_provider = args.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = args.get("embedding_model")
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") if embedding_model_provider and embedding_model:
) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
try: try:
@ -283,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"], embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"], embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"]) retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None if args["retrieval_model"] is not None
else None, else None,
) )
@ -317,7 +314,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
@ -331,8 +328,8 @@ class DatasetApi(DatasetApiResource):
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality": if data.get("indexing_technique") == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names: if item_model in model_names:
data["embedding_available"] = True data["embedding_available"] = True
else: else:
@ -341,7 +338,9 @@ class DatasetApi(DatasetApiResource):
data["embedding_available"] = True data["embedding_available"] = True
# force update search method to keyword_search if indexing_technique is economic # force update search method to keyword_search if indexing_technique is economic
data["retrieval_model_dict"]["search_method"] = "keyword_search" retrieval_model_dict = data.get("retrieval_model_dict")
if retrieval_model_dict:
retrieval_model_dict["search_method"] = "keyword_search"
if data.get("permission") == "partial_members": if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -372,19 +371,24 @@ class DatasetApi(DatasetApiResource):
data = request.get_json() data = request.get_json()
# check embedding model setting # check embedding model setting
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): embedding_model_provider = data.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = data.get("embedding_model")
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
) if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
if ( if (
data.get("retrieval_model") retrieval_model
and data.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
dataset.tenant_id, dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -397,7 +401,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -591,9 +595,10 @@ class DatasetTagsApi(DatasetApiResource):
args = tag_update_parser.parse_args() args = tag_update_parser.parse_args()
args["type"] = "knowledge" args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id")) tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(args.get("tag_id")) binding_count = TagService.get_tag_binding_count(tag_id)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -616,7 +621,7 @@ class DatasetTagsApi(DatasetApiResource):
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id")) TagService.delete_tag(args["tag_id"])
return 204 return 204

View File

@ -30,7 +30,6 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from models.model import EndUser
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
@ -109,19 +108,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None: if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.") raise ValueError("Both 'text' and 'name' must be non-null values.")
if args.get("embedding_model_provider"): embedding_model_provider = args.get("embedding_model_provider")
DatasetService.check_embedding_model_setting( embedding_model = args.get("embedding_model")
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") if embedding_model_provider and embedding_model:
) DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
if not current_user: if not current_user:
@ -135,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
} }
args["data_source"] = data_source args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
# validate args # validate args
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
@ -188,15 +189,16 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model") retrieval_model
and args.get("retrieval_model").get("reranking_model") and retrieval_model.get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), retrieval_model.get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), retrieval_model.get("reranking_model").get("reranking_model_name"),
) )
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
@ -219,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:
@ -311,8 +313,6 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
} }
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@ -406,9 +406,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
try: try:
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file.filename, filename=file.filename,
@ -429,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)
try: try:

View File

@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset.""" """Create metadata for a dataset."""
args = metadata_create_parser.parse_args() args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args) metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -106,7 +106,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc("delete_dataset_metadata")
@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args() args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args) metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json") parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args() args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService() rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args() args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment( updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
) )
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
Create or update session terminal based on user ID. Create or update session terminal based on user ID.
""" """
if not user_id: if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
end_user = ( end_user = (
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_id=app_model.id, app_id=app_model.id,
type="service_api", type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value, is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id, session_id=user_id,
) )
session.add(end_user) session.add(end_user)

View File

@ -126,6 +126,8 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user_id = enterprise_user_decoded.get("end_user_id") end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id") session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type") user_auth_type = enterprise_user_decoded.get("auth_type")
exchanged_token_expires_unix = enterprise_user_decoded.get("exp")
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
@ -169,8 +171,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
) )
db.session.add(end_user) db.session.add(end_user)
db.session.commit() db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp()) exp = int((datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)).timestamp())
if exchanged_token_expires_unix:
exp = int(exchanged_token_expires_unix)
payload = { payload = {
"iss": site.id, "iss": site.id,
"sub": "Web API Passport", "sub": "Web API Passport",

View File

@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None), "credential_id": tool.get("credential_id", None),
} }
agent_tools.append(AgentToolEntity(**agent_tool_properties)) agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router", "react_router",

View File

@ -1,4 +1,5 @@
import uuid import uuid
from typing import Literal, cast
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,
@ -74,6 +75,9 @@ class DatasetConfigManager:
return None return None
query_variable = config.get("dataset_query_variable") query_variable = config.get("dataset_query_variable")
metadata_model_config_dict = dataset_configs.get("metadata_model_config")
metadata_filtering_conditions_dict = dataset_configs.get("metadata_filtering_conditions")
if dataset_configs["retrieval_model"] == "single": if dataset_configs["retrieval_model"] == "single":
return DatasetEntity( return DatasetEntity(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
@ -82,18 +86,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"] dataset_configs["retrieval_model"]
), ),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), metadata_filtering_mode=cast(
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) Literal["disabled", "automatic", "manual"],
if dataset_configs.get("metadata_model_config") dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None, else None,
metadata_filtering_conditions=MetadataFilteringCondition( metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
**dataset_configs.get("metadata_filtering_conditions", {}) if isinstance(metadata_filtering_conditions_dict, dict)
)
if dataset_configs.get("metadata_filtering_conditions")
else None, else None,
), ),
) )
else: else:
score_threshold_val = dataset_configs.get("score_threshold")
reranking_model_val = dataset_configs.get("reranking_model")
weights_val = dataset_configs.get("weights")
return DatasetEntity( return DatasetEntity(
dataset_ids=dataset_ids, dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity( retrieve_config=DatasetRetrieveConfigEntity(
@ -101,22 +110,23 @@ class DatasetConfigManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"] dataset_configs["retrieval_model"]
), ),
top_k=dataset_configs.get("top_k", 4), top_k=int(dataset_configs.get("top_k", 4)),
score_threshold=dataset_configs.get("score_threshold") score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None, else None,
reranking_model=dataset_configs.get("reranking_model"), reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=dataset_configs.get("weights"), weights=weights_val if isinstance(weights_val, dict) else None,
reranking_enabled=dataset_configs.get("reranking_enabled", True), reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), metadata_filtering_mode=cast(
metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) Literal["disabled", "automatic", "manual"],
if dataset_configs.get("metadata_model_config") dataset_configs.get("metadata_filtering_mode", "disabled"),
),
metadata_model_config=ModelConfig(**metadata_model_config_dict)
if isinstance(metadata_model_config_dict, dict)
else None, else None,
metadata_filtering_conditions=MetadataFilteringCondition( metadata_filtering_conditions=MetadataFilteringCondition(**metadata_filtering_conditions_dict)
**dataset_configs.get("metadata_filtering_conditions", {}) if isinstance(metadata_filtering_conditions_dict, dict)
)
if dataset_configs.get("metadata_filtering_conditions")
else None, else None,
), ),
) )
@ -134,18 +144,17 @@ class DatasetConfigManager:
config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config) config = cls.extract_dataset_config_for_legacy_compatibility(tenant_id, app_mode, config)
# dataset_configs # dataset_configs
if not config.get("dataset_configs"): if "dataset_configs" not in config or not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"} config["dataset_configs"] = {}
config["dataset_configs"]["retrieval_model"] = config["dataset_configs"].get("retrieval_model", "single")
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
if not config["dataset_configs"].get("datasets"): if "datasets" not in config["dataset_configs"] or not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( need_manual_query_datasets = config.get("dataset_configs", {}).get("datasets", {}).get("datasets")
"datasets", {}
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion
@ -166,8 +175,8 @@ class DatasetConfigManager:
:param config: app model config args :param config: app model config args
""" """
# Extract dataset config for legacy compatibility # Extract dataset config for legacy compatibility
if not config.get("agent_mode"): if "agent_mode" not in config or not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []} config["agent_mode"] = {}
if not isinstance(config["agent_mode"], dict): if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
@ -180,19 +189,22 @@ class DatasetConfigManager:
raise ValueError("enabled in agent_mode must be of boolean type") raise ValueError("enabled in agent_mode must be of boolean type")
# tools # tools
if not config["agent_mode"].get("tools"): if "tools" not in config["agent_mode"] or not config["agent_mode"].get("tools"):
config["agent_mode"]["tools"] = [] config["agent_mode"]["tools"] = []
if not isinstance(config["agent_mode"]["tools"], list): if not isinstance(config["agent_mode"]["tools"], list):
raise ValueError("tools in agent_mode must be a list of objects") raise ValueError("tools in agent_mode must be a list of objects")
# strategy # strategy
if not config["agent_mode"].get("strategy"): if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
has_datasets = False has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: if config.get("agent_mode", {}).get("strategy") in {
for tool in config["agent_mode"]["tools"]: PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key == "dataset": if key == "dataset":
# old style, use tool name as key # old style, use tool name as key
@ -217,7 +229,7 @@ class DatasetConfigManager:
has_datasets = True has_datasets = True
need_manual_query_datasets = has_datasets and config["agent_mode"]["enabled"] need_manual_query_datasets = has_datasets and config.get("agent_mode", {}).get("enabled")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion

View File

@ -68,7 +68,7 @@ class ModelConfigConverter:
# get model mode # get model mode
model_mode = model_config.mode model_mode = model_config.mode
if not model_mode: if not model_mode:
model_mode = LLMMode.CHAT.value model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value

View File

@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
if config["model"]["mode"] not in model_mode_vals: if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
if not assistant_prefix: if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config["model"]["mode"] == ModelMode.CHAT.value: if config["model"]["mode"] == ModelMode.CHAT:
prompt_list = config["chat_prompt_config"]["prompt"] prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10: if len(prompt_list) > 10:

View File

@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# if only single iteration run is requested # Handle single iteration or single loop run
graph_runtime_state = GraphRuntimeState( graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -551,7 +551,7 @@ class AdvancedChatAppGenerateTaskPipeline:
total_steps=validated_state.node_run_steps, total_steps=validated_state.node_run_steps,
outputs=event.outputs, outputs=event.outputs,
exceptions_count=event.exceptions_count, exceptions_count=event.exceptions_count,
conversation_id=None, conversation_id=self._conversation_id,
trace_manager=trace_manager, trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
) )

View File

@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode must be of boolean type") raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"): if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list") raise ValueError("strategy in agent_mode must be in the specified strategy list")

View File

@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
# start agent runner # start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode # check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value: elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner runner_cls = CotCompletionAgentRunner
else: else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")

View File

@ -1,9 +1,11 @@
import logging
import queue import queue
import time import time
from abc import abstractmethod from abc import abstractmethod
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any from typing import Any
from redis.exceptions import RedisError
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
from configs import dify_config from configs import dify_config
@ -18,6 +20,8 @@ from core.app.entities.queue_entities import (
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class PublishFrom(IntEnum): class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto() APPLICATION_MANAGER = auto()
@ -35,9 +39,8 @@ class AppQueueManager:
self.invoke_from = invoke_from # Public accessor for invoke_from self.invoke_from = invoke_from # Public accessor for invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex( self._task_belong_cache_key = AppQueueManager._generate_task_belong_cache_key(self._task_id)
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" redis_client.setex(self._task_belong_cache_key, 1800, f"{user_prefix}-{self._user_id}")
)
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
@ -79,9 +82,21 @@ class AppQueueManager:
Stop listen to queue Stop listen to queue
:return: :return:
""" """
self._clear_task_belong_cache()
self._q.put(None) self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom): def _clear_task_belong_cache(self) -> None:
"""
Remove the task belong cache key once listening is finished.
"""
try:
redis_client.delete(self._task_belong_cache_key)
except RedisError:
logger.exception(
"Failed to clear task belong cache for task %s (key: %s)", self._task_id, self._task_belong_cache_key
)
def publish_error(self, e, pub_from: PublishFrom) -> None:
""" """
Publish error Publish error
:param e: error :param e: error

View File

@ -61,9 +61,6 @@ class AppRunner:
if model_context_tokens is None: if model_context_tokens is None:
return -1 return -1
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens: if prompt_tokens + max_tokens > model_context_tokens:

View File

@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
workflow_execution_id=str(uuid.uuid4()), workflow_execution_id=str(uuid.uuid4()),
single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
) )
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )
def single_loop_generate( def single_loop_generate(
@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader, variable_loader=var_loader,
context=contextvars.copy_context(),
) )
def _generate_worker( def _generate_worker(

View File

@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
db.session.close() db.session.close()
# if only single iteration run is requested # if only single iteration run is requested
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState( # Handle single iteration or single loop run
variable_pool=VariablePool.empty(), graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
start_at=time.time(),
)
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
@ -133,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = [] rag_pipeline_variables = []
if workflow.rag_pipeline_variables: if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v) rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if ( if (
rag_pipeline_variable.belong_to_node_id rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared") in (self.application_generate_entity.start_node_id, "shared")
@ -246,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id, workflow_id=workflow.id,
graph_config=graph_config, graph_config=graph_config,
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API.value, invoke_from=InvokeFrom.SERVICE_API,
call_depth=0, call_depth=0,
) )

View File

@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
# if only single iteration run is requested # if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# if only single iteration run is requested graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, single_iteration_run=self.application_generate_entity.single_iteration_run,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, single_loop_run=self.application_generate_entity.single_loop_run,
graph_runtime_state=graph_runtime_state,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool.empty(),
start_at=time.time(),
)
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
graph_runtime_state=graph_runtime_state,
) )
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs

View File

@ -1,3 +1,4 @@
import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
@ -99,8 +100,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id, workflow_id=workflow_id,
graph_config=graph_config, graph_config=graph_config,
user_id=user_id, user_id=user_id,
user_from=UserFrom.ACCOUNT.value, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API.value, invoke_from=InvokeFrom.SERVICE_API,
call_depth=0, call_depth=0,
) )
@ -119,15 +120,81 @@ class WorkflowBasedAppRunner:
return graph return graph
def _get_graph_and_variable_pool_of_single_iteration( def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
(either single iteration or single loop).
Args:
workflow: The workflow instance
single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise
single_loop_run: SingleLoopRunEntity if running single loop, None otherwise
Returns:
A tuple containing (graph, variable_pool, graph_runtime_state)
Raises:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
),
start_at=time.time(),
)
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
# Return the graph, variable_pool, and the same graph_runtime_state used during graph creation
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
def _get_graph_and_variable_pool_for_single_node_run(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single iteration Get graph and variable pool for single node execution (iteration or loop).
Args:
workflow: The workflow instance
node_id: The node ID to execute
user_inputs: User inputs for the node
graph_runtime_state: The graph runtime state
node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id')
node_type_label: Label for error messages ('iteration' or 'loop')
Returns:
A tuple containing (graph, variable_pool)
""" """
# fetch workflow graph # fetch workflow graph
graph_config = workflow.graph_dict graph_config = workflow.graph_dict
@ -145,18 +212,22 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list): if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list") raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration # filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [ node_configs = [
node node
for node in graph_config.get("nodes", []) for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
] ]
graph_config["nodes"] = node_configs graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs] node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration # filter edges only in the specified node type
edge_configs = [ edge_configs = [
edge edge
for edge in graph_config.get("edges", []) for edge in graph_config.get("edges", [])
@ -173,8 +244,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow.id, workflow_id=workflow.id,
graph_config=graph_config, graph_config=graph_config,
user_id="", user_id="",
user_from=UserFrom.ACCOUNT.value, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API.value, invoke_from=InvokeFrom.SERVICE_API,
call_depth=0, call_depth=0,
) )
@ -190,30 +261,26 @@ class WorkflowBasedAppRunner:
raise ValueError("graph not found in workflow") raise ValueError("graph not found in workflow")
# fetch node config from node id # fetch node config from node id
iteration_node_config = None target_node_config = None
for node in node_configs: for node in node_configs:
if node.get("id") == node_id: if node.get("id") == node_id:
iteration_node_config = node target_node_config = node
break break
if not iteration_node_config: if not target_node_config:
raise ValueError("iteration node id not found in workflow graph") raise ValueError(f"{node_type_label} node id not found in workflow graph")
# Get node class # Get node class
node_type = NodeType(iteration_node_config.get("data", {}).get("type")) node_type = NodeType(target_node_config.get("data", {}).get("type"))
node_version = iteration_node_config.get("data", {}).get("version", "1") node_version = target_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool # Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = VariablePool( variable_pool = graph_runtime_state.variable_pool
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try: try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config graph_config=workflow.graph_dict, config=target_node_config
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
@ -234,120 +301,44 @@ class WorkflowBasedAppRunner:
return graph, variable_pool return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop( def _get_graph_and_variable_pool_of_single_loop(
self, self,
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user_inputs: dict, user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]: ) -> tuple[Graph, VariablePool]:
""" """
Get variable pool of single loop Get variable pool of single loop
""" """
# fetch workflow graph return self._get_graph_and_variable_pool_for_single_node_run(
graph_config = workflow.graph_dict workflow=workflow,
if not graph_config: node_id=node_id,
raise ValueError("workflow graph not found") user_inputs=user_inputs,
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in loop
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in loop
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
) )
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
loop_node_config = None
for node in node_configs:
if node.get("id") == node_id:
loop_node_config = node
break
if not loop_node_config:
raise ValueError("loop node id not found in workflow graph")
# Get node class
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
node_version = loop_node_config.get("data", {}).get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=loop_node_config
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
""" """
Handle event Handle event

View File

@ -107,7 +107,6 @@ class MessageCycleManager:
if dify_config.DEBUG: if dify_config.DEBUG:
logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) logger.exception("generate conversation name failed, conversation_id: %s", conversation_id)
db.session.merge(conversation)
db.session.commit() db.session.commit()
db.session.close() db.session.close()

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from openai import BaseModel from pydantic import BaseModel, Field
from pydantic import Field
# Import InvokeFrom locally to avoid circular import # Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
for datasource in datasources: for datasource in datasources:
if datasource.get("parameters"): if datasource.get("parameters"):
for parameter in datasource.get("parameters"): for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value: if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
parameter["type"] = "files" parameter["type"] = "files"
# ------------- # -------------

Some files were not shown because too many files have changed in this diff Show More