This commit is contained in:
Joel 2025-11-17 15:42:56 +08:00
commit 9a07488da9
1423 changed files with 85420 additions and 11505 deletions

View File

@ -11,7 +11,7 @@
"nodeGypDependencies": true,
"version": "lts"
},
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
"ghcr.io/devcontainers-extra/features/npm-package:1": {
"package": "typescript",
"version": "latest"
},

View File

@ -6,11 +6,10 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -2,6 +2,8 @@ name: autofix.ci
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: read

View File

@ -103,6 +103,11 @@ jobs:
run: |
pnpm run lint
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest

9
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# *db files
*.db
# Distribution / packaging
.Python
build/
@ -97,6 +100,7 @@ __pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat-schedule.db
celerybeat.pid
# SageMath parsed files
@ -234,4 +238,7 @@ scripts/stress-test/reports/
# mcp
.playwright-mcp/
.serena/
.serena/
# settings
*.local.json

View File

@ -8,8 +8,7 @@
"module": "flask",
"env": {
"FLASK_APP": "app.py",
"FLASK_ENV": "development",
"GEVENT_SUPPORT": "True"
"FLASK_ENV": "development"
},
"args": [
"run",
@ -28,9 +27,7 @@
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"GEVENT_SUPPORT": "True"
},
"env": {},
"args": [
"-A",
"app.celery",
@ -40,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,generation,mail,ops_trace",
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
"--loglevel",
"INFO"
],

View File

@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
> - CPU >= 2 Core
> - RAM >= 4 GiB
</br>
<br/>
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
## Using Dify
- **Cloud </br>**
- **Cloud <br/>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
- **Self-hosting Dify Community Edition</br>**
- **Self-hosting Dify Community Edition<br/>**
Quickly get Dify running in your environment with this [starter guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
- **Dify for enterprise / organizations</br>**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
- **Dify for enterprise / organizations<br/>**
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.

View File

@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
# Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
@ -156,6 +159,9 @@ SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
@ -368,6 +374,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Comma-separated list of file extensions blocked from upload for security reasons.
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
# Empty by default to allow all file types.
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
@ -457,6 +469,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
HTTP_REQUEST_NODE_SSL_VERIFY=True
# Webhook request configuration
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
# Respect X-* headers to redirect clients
RESPECT_XFORWARD_HEADERS_ENABLED=false
@ -512,7 +527,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# Workflow log cleanup configuration
# Enable automatic cleanup of workflow run logs to manage database size
WORKFLOW_LOG_CLEANUP_ENABLED=true
WORKFLOW_LOG_CLEANUP_ENABLED=false
# Number of days to retain workflow run logs (default: 30 days)
WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
@ -534,6 +549,12 @@ ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
# Interval time in minutes for polling scheduled workflows(default: 1 min)
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
# Position configuration
POSITION_TOOL_PINS=
@ -605,3 +626,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0

View File

@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
]
}
]

62
api/AGENTS.md Normal file
View File

@ -0,0 +1,62 @@
# Agent Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
______________________________________________________________________
## Platform Foundations
- **[Infrastructure Overview](agent_skills/infra.md)**\
When to read this:
- You need to understand where a feature belongs in the architecture.
- Youre wiring storage, Redis, vector stores, or OTEL.
- Youre about to add CLI commands or async jobs.\
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
- **[Coding Style](agent_skills/coding_style.md)**\
When to read this:
- Youre writing or reviewing backend code and need the authoritative checklist.
- Youre unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- You want the exact lint/type/test commands used in PRs.\
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
______________________________________________________________________
## Plugin & Extension Development
- **[Plugin Systems](agent_skills/plugin.md)**\
When to read this:
- Youre building or debugging a marketplace plugin.
- You need to know how manifests, providers, daemons, and migrations fit together.\
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
When to read this:
- You must integrate OAuth for a plugin or datasource.
- Youre handling credential encryption or refresh flows.\
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
______________________________________________________________________
## Workflow Entry & Execution
- **[Trigger Concepts](agent_skills/trigger.md)**\
When to read this:
- Youre debugging why a workflow didnt start.
- Youre adding a new trigger type or hook.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
______________________________________________________________________
## Additional Notes for Agents
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.

View File

@ -15,7 +15,11 @@ FROM base AS packages
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
COPY pyproject.toml uv.lock ./
@ -49,7 +53,9 @@ RUN \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
curl nodejs \
# for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2

View File

@ -80,7 +80,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -0,0 +1,115 @@
## Linter
- Always follow `.ruff.toml`.
- Run `uv run ruff check --fix --unsafe-fixes`.
- Keep each line under 100 characters (including spaces).
## Code Style
- `snake_case` for variables and functions.
- `PascalCase` for classes.
- `UPPER_CASE` for constants.
## Rules
- Use Pydantic v2 standard.
- Use `uv` for package management.
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
- Prefer simple functions over classes for lightweight helpers.
- Keep files below 800 lines; split when necessary.
- Keep code readable—no clever hacks.
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
## Guiding Principles
- Mirror the projects layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
## SQLAlchemy Patterns
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
- Open sessions with context managers:
```python
from sqlalchemy.orm import Session
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
## Storage & External IO
- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
## Pydantic Usage
- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.
- Example:
```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
class TriggerConfig(BaseModel):
endpoint: HttpUrl
secret: str
model_config = ConfigDict(extra="forbid")
@field_validator("secret")
def ensure_secret_prefix(cls, value: str) -> str:
if not value.startswith("dify_"):
raise ValueError("secret must start with dify_")
return value
```
## Generics & Protocols
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
## Error Handling & Logging
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
- Declare `logger = logging.getLogger(__name__)` at module top.
- Include tenant/app/workflow identifiers in log context.
- Log retryable events at `warning`, terminal failures at `error`.
## Tooling & Checks
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
- Type checks: `uv run --directory api --dev basedpyright`.
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Run all of the above before submitting your work.
## Controllers & Services
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
- Document non-obvious behaviour with concise comments.
## Miscellaneous
- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.

96
api/agent_skills/infra.md Normal file
View File

@ -0,0 +1,96 @@
## Configuration
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
## Dependencies
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
## Storage & Files
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
## Redis & Shared State
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
## Models
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
## Vector Stores
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
## Observability & OTEL
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
## Ops Integrations
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
## Controllers, Services, Core
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
## Plugins, Tools, Providers
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
## Async Workloads
see `agent_skills/trigger.md` for more detailed documentation.
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
## Database & Migrations
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
## CLI Commands
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
## When You Add Features
- Check for an existing helper or service before writing a new util.
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.

View File

@ -0,0 +1 @@
// TBD

View File

@ -0,0 +1 @@
// TBD

View File

@ -0,0 +1,53 @@
## Overview
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
## Trigger nodes
- `UserInput`
- `Trigger Webhook`
- `Trigger Schedule`
- `Trigger Plugin`
### UserInput
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
### Trigger Webhook
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
### Trigger Schedule
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
### Trigger Plugin
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
## Worker Pool / Async Task
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
## Debug Strategy
Dify divided users into 2 groups: builders / end users.
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.

View File

@ -1,7 +1,7 @@
import sys
def is_db_command():
def is_db_command() -> bool:
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
@ -13,23 +13,12 @@ if is_db_command():
app = create_migrations_app()
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
#
# # gevent
# monkey.patch_all()
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
from app_factory import create_app

View File

@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
if not datasets.items:
break
except SQLAlchemyError:
raise
@ -1227,6 +1229,55 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.
@ -1420,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
@click.option(
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
"""
Transform datasource credentials
"""
@ -1431,9 +1485,14 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
if environment == "online":
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
else:
notion_plugin_unique_identifier = None
firecrawl_plugin_unique_identifier = None
jina_plugin_unique_identifier = None
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
@ -1599,7 +1658,7 @@ def transform_datasource_credentials():
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
provider="jinareader",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,

View File

@ -174,6 +174,33 @@ class CodeExecutionSandboxConfig(BaseSettings):
)
class TriggerConfig(BaseSettings):
"""
Configuration for trigger
"""
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
description="Maximum allowed size for webhook request bodies in bytes",
default=10485760,
)
class AsyncWorkflowConfig(BaseSettings):
"""
Configuration for async workflow
"""
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
description="Granularity for async workflow scheduler, "
"sometime, few users could block the queue due to some time-consuming tasks, "
"to avoid this, workflow can be suspended if needed, to achieve"
"this, a time-based checker is required, every granularity seconds, "
"the checker will check the workflow queue and suspend the workflow",
default=120,
ge=1,
)
class PluginConfig(BaseSettings):
"""
Plugin configs
@ -263,6 +290,8 @@ class EndpointConfig(BaseSettings):
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
)
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
class FileAccessConfig(BaseSettings):
"""
@ -331,12 +360,42 @@ class FileUploadConfig(BaseSettings):
default=10,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
"Empty by default to allow all file types."
),
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
default="",
)
@computed_field # type: ignore[misc]
@property
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
"""
Parse and return the blacklist as a set of lowercase extensions.
Returns an empty set if no blacklist is configured.
"""
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
return set()
return {
ext.strip().lower().strip(".")
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
if ext.strip()
}
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
"""
COOKIE_DOMAIN: str = Field(
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
default="",
)
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
default=False,
@ -915,6 +974,11 @@ class DataSetConfig(BaseSettings):
default=True,
)
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
description="Maximum number of segments for dataset segments API (0 for unlimited)",
default=0,
)
class WorkspaceConfig(BaseSettings):
"""
@ -990,6 +1054,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable check upgradable plugin task",
default=True,
)
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
description="Enable workflow schedule poller task",
default=True,
)
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
description="Workflow schedule poller interval in minutes",
default=1,
)
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
description="Maximum number of schedules to process in each poll batch",
default=100,
)
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
default=0,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",
default=True,
)
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
description="Trigger provider refresh poller interval in minutes",
default=1,
)
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
description="Max trigger subscriptions to process per tick",
default=200,
)
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds",
default=180,
)
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds",
default=60 * 60,
)
class PositionConfig(BaseSettings):
@ -1088,7 +1190,7 @@ class AccountConfig(BaseSettings):
class WorkflowLogConfig(BaseSettings):
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
default=100, description="Batch size for workflow run log cleanup operations"
@ -1107,12 +1209,21 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,
MarketplaceConfig,
DataSetConfig,
@ -1131,6 +1242,7 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,

View File

@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
default=True,
)
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
default=None,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
default=100,

View File

@ -56,11 +56,15 @@ else:
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
# webapp
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
COOKIE_NAME_PASSPORT = "passport"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

File diff suppressed because one or more lines are too long

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
"""
@ -41,3 +42,11 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
ContextVar("plugin_trigger_providers")
)
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_trigger_providers_lock")
)

View File

@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class BlockedFileExtensionError(BaseHTTPException):
error_code = "file_extension_blocked"
description = "The file extension is blocked for security reasons."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."

View File

@ -66,6 +66,7 @@ from .app import (
workflow_draft_variable,
workflow_run,
workflow_statistic,
workflow_trigger,
)
# Import auth controllers
@ -126,6 +127,7 @@ from .workspace import (
models,
plugin,
tool_providers,
trigger_providers,
workspace,
)
@ -196,6 +198,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trigger_providers",
"version",
"website",
"workflow",
@ -203,5 +206,6 @@ __all__ = [
"workflow_draft_variable",
"workflow_run",
"workflow_statistic",
"workflow_trigger",
"workspace",
]

View File

@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.expect(parser)
@api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@ -25,13 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args")
.add_argument("model_mode", type=str, required=True, location="args")
.add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -8,17 +8,19 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
)
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource):
@api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
)
@api.expect(parser)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters")
@setup_required
@ -27,12 +29,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

View File

@ -16,6 +16,7 @@ from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
@ -175,8 +176,10 @@ class AnnotationApi(Resource):
api.model(
"CreateAnnotationRequest",
{
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
@ -193,11 +196,14 @@ class AnnotationApi(Resource):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation
@setup_required
@ -245,6 +251,13 @@ class AnnotationExportApi(Resource):
return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation")
@ -253,6 +266,7 @@ class AnnotationUpdateDeleteApi(Resource):
@api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -262,11 +276,6 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation

View File

@ -15,11 +15,12 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -106,6 +107,35 @@ class AppListApi(Resource):
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
workflow_capable_app_ids = [
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
)
)
.scalars()
.all()
)
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_fields), 200
@api.doc("create_app")
@ -353,12 +383,15 @@ class AppExportApi(Resource):
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@api.doc("check_app_name")
@api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
@api.expect(parser)
@api.response(200, "Name availability checked")
@setup_required
@login_required
@ -367,7 +400,6 @@ class AppNameApi(Resource):
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()

View File

@ -1,6 +1,7 @@
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -18,9 +19,23 @@ from services.feature_service import FeatureService
from .. import console_ns
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -30,18 +45,6 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session

View File

@ -1,7 +1,5 @@
from datetime import datetime
import pytz
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
@ -19,7 +17,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
@ -270,29 +260,21 @@ class ChatConversationApi(Resource):
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)

View File

@ -11,6 +11,7 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
@ -206,13 +207,11 @@ class InstructionGenerateApi(Resource):
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
else (JavascriptCodeProvider.get_default_code())
if args["language"] == "javascript"
else ""
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args["language"])), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":

View File

@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
@api.doc("create_message_annotation")
@api.doc(description="Create message annotation")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"MessageAnnotationRequest",
{
"message_id": fields.String(description="Message ID"),
"question": fields.String(required=True, description="Question text"),
"answer": fields.String(required=True, description="Answer text"),
"annotation_reply": fields.Raw(description="Annotation reply"),
},
)
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@cloud_edition_billing_resource_check("annotation")
@account_initialization_required
@edit_permission_required
def post(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
return annotation
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count")

View File

@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@ -56,26 +55,16 @@ WHERE
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -91,16 +80,19 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily conversation statistics retrieved successfully",
@ -113,15 +105,13 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
stmt = (
sa.select(
@ -134,18 +124,10 @@ class DailyConversationStatistic(Resource):
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
if start_datetime_utc:
stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
stmt = stmt.where(Message.created_at < end_datetime_utc)
stmt = stmt.group_by("date").order_by("date")
@ -164,11 +146,7 @@ class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily terminal statistics retrieved successfully",
@ -181,11 +159,6 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -198,26 +171,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -238,11 +202,7 @@ class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily token cost statistics retrieved successfully",
@ -255,11 +215,6 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -273,26 +228,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -315,11 +261,7 @@ class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average session interaction statistics retrieved successfully",
@ -332,11 +274,6 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -357,26 +294,17 @@ FROM
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND c.created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -408,11 +336,7 @@ class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"User satisfaction rate statistics retrieved successfully",
@ -425,11 +349,6 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -446,26 +365,17 @@ WHERE
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND m.created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -491,11 +401,7 @@ class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average response time statistics retrieved successfully",
@ -508,11 +414,6 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -525,26 +426,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
@ -565,11 +457,7 @@ class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Tokens per second statistics retrieved successfully",
@ -581,12 +469,6 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
@ -602,26 +484,17 @@ WHERE
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
if end_datetime_utc:
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc

View File

@ -16,9 +16,19 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.debug.event_selectors import (
TriggerDebugEvent,
TriggerDebugEventPoller,
create_event_poller,
select_trigger_debug_events,
)
from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
@ -37,6 +47,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@ -102,7 +113,18 @@ class DraftWorkflowApi(Resource):
},
)
)
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(
200,
"Draft workflow synced successfully",
api.model(
"SyncDraftWorkflowResponse",
{
"result": fields.String,
"hash": fields.String,
"updated_at": fields.String,
},
),
)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
@ -564,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@api.doc("get_published_workflow")
@ -588,6 +617,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@api.expect(parser_publish)
@setup_required
@login_required
@account_initialization_required
@ -598,12 +628,8 @@ class PublishedWorkflowApi(Resource):
Publish workflow
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
@ -658,6 +684,9 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@api.doc("get_default_block_config")
@ -665,6 +694,7 @@ class DefaultBlockConfigApi(Resource):
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@api.response(200, "Default block configuration retrieved successfully")
@api.response(404, "Block type not found")
@api.expect(parser_block)
@setup_required
@login_required
@account_initialization_required
@ -674,8 +704,7 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_block.parse_args()
q = args.get("q")
@ -691,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@api.expect(parser_convert)
@api.doc("convert_to_workflow")
@api.doc(description="Convert application to workflow mode")
@api.doc(params={"app_id": "Application ID"})
@ -713,14 +752,7 @@ class ConvertToWorkflowApi(Resource):
current_user, _ = current_account_with_tenant()
if request.data:
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_convert.parse_args()
else:
args = {}
@ -734,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.expect(parser_workflows)
@api.doc("get_all_published_workflows")
@api.doc(description="Get all published workflows for an application")
@api.doc(params={"app_id": "Application ID"})
@ -752,16 +794,9 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@ -915,3 +950,234 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
class DraftWorkflowTriggerRunApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
"""
@api.doc("poll_draft_workflow_trigger_run")
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerRunRequest",
{
"node_id": fields.String(required=True, description="Node ID"),
},
)
)
@api.response(200, "Trigger event received and workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
poller: TriggerDebugEventPoller = create_event_poller(
draft_workflow=draft_workflow,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
event: TriggerDebugEvent | None = None
try:
event = poller.poll()
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=node_id,
)
)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run")
class DraftWorkflowTriggerNodeApi(Resource):
"""
Single node debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
"""
@api.doc("poll_draft_workflow_trigger_node")
@api.doc(description="Poll for trigger events and execute single node when event arrives")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Trigger event received and node executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
event: TriggerDebugEvent | None = None
# for schedule trigger, when run single node, just execute directly
if node_type == NodeType.TRIGGER_SCHEDULE:
event = TriggerDebugEvent(
workflow_args={},
node_id=node_id,
)
# for other trigger types, poll for the event
else:
try:
poller: TriggerDebugEventPoller = create_event_poller(
draft_workflow=draft_workflow,
tenant_id=app_model.tenant_id,
user_id=current_user.id,
app_id=app_model.id,
node_id=node_id,
)
event = poller.poll()
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
raw_files = event.workflow_args.get("files")
files = _parse_file(draft_workflow, raw_files if isinstance(raw_files, list) else None)
try:
node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=event.workflow_args.get("inputs") or {},
account=current_user,
query="",
files=files,
)
return jsonable_encoder(node_execution)
except Exception as e:
logger.exception("Error running draft workflow trigger node")
return jsonable_encoder(
{"status": "error", "error": "An unexpected error occurred while running the node."}
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
class DraftWorkflowTriggerRunAllApi(Resource):
"""
Full workflow debug - Polling API for trigger events
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
"""
@api.doc("draft_workflow_trigger_run_all")
@api.doc(description="Full workflow debug when the start node is a trigger")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@api.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Full workflow debug when the start node is a trigger
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
raise ValueError("Workflow not found")
try:
trigger_debug_event: TriggerDebugEvent | None = select_trigger_debug_events(
draft_workflow=draft_workflow,
app_model=app_model,
user_id=current_user.id,
node_ids=node_ids,
)
except PluginInvokeError as e:
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
except Exception as e:
logger.exception("Error polling trigger debug event")
raise e
if trigger_debug_event is None:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
try:
workflow_args = dict(trigger_debug_event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=workflow_args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
root_node_id=trigger_debug_event.node_id,
)
return helper.compact_generate_response(response)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except Exception:
logger.exception("Error running draft workflow trigger run-all")
return jsonable_encoder(
{
"status": "error",
}
), 400

View File

@ -28,6 +28,7 @@ class WorkflowAppLogApi(Resource):
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
@ -68,6 +69,7 @@ class WorkflowAppLogApi(Resource):
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
@ -92,6 +94,7 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
detail=args.detail,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
)

View File

@ -30,23 +30,25 @@ def _parse_workflow_run_list_args():
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()
@ -58,28 +60,30 @@ def _parse_workflow_run_count_args():
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()

View File

@ -1,23 +1,26 @@
from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_runs_statistic")
@api.doc(description="Get workflow daily runs statistics")
@api.doc(params={"app_id": "Application ID"})
@ -37,57 +40,32 @@ class WorkflowDailyRunsStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
response_data = self._workflow_run_repo.get_daily_runs_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
class WorkflowDailyTerminalsStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_terminals_statistic")
@api.doc(description="Get workflow daily terminals statistics")
@api.doc(params={"app_id": "Application ID"})
@ -107,57 +85,32 @@ class WorkflowDailyTerminalsStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
class WorkflowDailyTokenCostStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_token_cost_statistic")
@api.doc(description="Get workflow daily token cost statistics")
@api.doc(params={"app_id": "Application ID"})
@ -177,62 +130,32 @@ class WorkflowDailyTokenCostStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
"date": str(i.date),
"token_count": i.token_count,
}
)
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
class WorkflowAverageAppInteractionStatistic(Resource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_average_app_interaction_statistic")
@api.doc(description="Get workflow average app interaction statistics")
@api.doc(params={"app_id": "Application ID"})
@ -252,67 +175,20 @@ class WorkflowAverageAppInteractionStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
try:
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
arg_dict["start"] = start_datetime_utc
else:
sql_query = sql_query.replace("{{start}}", "")
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
)
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
start_date=start_date,
end_date=end_date,
timezone=account.timezone,
)
return jsonify({"data": response_data})

View File

@ -0,0 +1,145 @@
import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger
logger = logging.getLogger(__name__)
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = str(args["node_id"])
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.first()
)
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
class AppTriggersApi(Resource):
"""App Triggers list API"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
select(AppTrigger)
.where(
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
)
.scalars()
.all()
)
# Add computed icon field for each trigger
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
for trigger in triggers:
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
AppTrigger.id == trigger_id,
AppTrigger.tenant_id == current_user.current_tenant_id,
AppTrigger.app_id == app_model.id,
)
).scalar_one_or_none()
if not trigger:
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
else:
trigger.icon = "" # type: ignore
return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@ -29,6 +29,7 @@ from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@ -270,7 +271,7 @@ class EmailCodeLoginApi(Resource):
class RefreshTokenApi(Resource):
def post(self):
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
refresh_token = extract_refresh_token(request)
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401

View File

@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -16,7 +17,13 @@ class Subscription(Resource):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()

View File

@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@ -121,8 +121,16 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@api.expect(parser_datasource)
@setup_required
@login_required
@account_initialization_required
@ -130,14 +138,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -168,8 +169,14 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@api.expect(parser_datasource_delete)
@setup_required
@login_required
@account_initialization_required
@ -181,10 +188,7 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_datasource_delete.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
@ -195,8 +199,17 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@api.expect(parser_datasource_update)
@setup_required
@login_required
@account_initialization_required
@ -205,13 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource_update.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
@ -251,8 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@api.expect(parser_datasource_custom)
@setup_required
@login_required
@account_initialization_required
@ -260,12 +275,7 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_datasource_custom.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
@ -291,8 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -300,8 +314,7 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
@ -312,8 +325,16 @@ class DatasourceAuthDefaultApi(Resource):
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@api.expect(parser_update_name)
@setup_required
@login_required
@account_initialization_required
@ -321,12 +342,7 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_update_name.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(

View File

@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
)
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
@ -12,9 +12,17 @@ from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -26,12 +34,6 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
inputs = args.get("inputs")

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
@ -148,8 +148,12 @@ class DraftRagPipelineApi(Resource):
}
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@ -162,8 +166,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_iteration(
@ -184,6 +187,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@ -197,8 +201,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_loop(
@ -217,8 +220,18 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@api.expect(parser_draft_run)
@setup_required
@login_required
@account_initialization_required
@ -232,14 +245,7 @@ class DraftRagPipelineRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_draft_run.parse_args()
try:
response = PipelineGenerateService.generate(
@ -255,8 +261,21 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@api.expect(parser_published_run)
@setup_required
@login_required
@account_initialization_required
@ -270,17 +289,7 @@ class PublishedRagPipelineRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_published_run.parse_args()
streaming = args["response_mode"] == "streaming"
@ -381,8 +390,17 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@account_initialization_required
@ -396,13 +414,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@ -429,6 +441,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@account_initialization_required
@ -442,13 +455,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@ -473,8 +480,14 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
)
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@api.expect(parser_run_api)
@setup_required
@login_required
@account_initialization_required
@ -489,10 +502,7 @@ class RagPipelineDraftNodeRunApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_run_api.parse_args()
inputs = args.get("inputs")
if inputs == None:
@ -607,8 +617,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -622,8 +636,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_default.parse_args()
q = args.get("q")
@ -639,8 +652,18 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@api.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@ -654,16 +677,9 @@ class PublishedAllRagPipelineApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@ -691,8 +707,16 @@ class PublishedAllRagPipelineApi(Resource):
}
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@api.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@ -707,19 +731,13 @@ class RagPipelineByIdApi(Resource):
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_wf_id.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}
@ -752,8 +770,12 @@ class RagPipelineByIdApi(Resource):
return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -763,8 +785,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -777,6 +798,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -786,8 +808,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -800,6 +821,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -809,8 +831,7 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -823,6 +844,7 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -832,8 +854,7 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -845,8 +866,16 @@ class DraftRagPipelineSecondStepApi(Resource):
}
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@api.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@ -856,12 +885,7 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = parser_wf_run.parse_args()
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -961,8 +985,18 @@ class RagPipelineTransformApi(Resource):
return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@api.expect(parser_var)
@setup_required
@login_required
@account_initialization_required
@ -974,14 +1008,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_var.parse_args()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@ -35,15 +35,18 @@ recommended_app_list_fields = {
}
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@api.expect(parser_apps)
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
args = parser.parse_args()
args = parser_apps.parse_args()
language = args.get("language")
if language and language in languages:

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user as current_user_
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -31,8 +31,6 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -40,6 +38,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
@ -53,7 +52,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
assert current_user is not None
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -89,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)

View File

@ -66,13 +66,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
args = api.payload
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
@ -125,13 +119,7 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args()
args = api.payload
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]

View File

@ -8,6 +8,7 @@ import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
@ -39,6 +40,7 @@ class FileApi(Resource):
return {
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
@ -82,6 +84,8 @@ class FileApi(Resource):
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
raise BlockedFileExtensionError(blocked_extension_error.description)
return upload_file, 201

View File

@ -10,6 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import api
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -36,12 +37,15 @@ class RemoteFileInfoApi(Resource):
}
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@api.expect(parser_upload)
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
args = parser_upload.parse_args()
url = args["url"]

View File

@ -49,6 +49,7 @@ class SetupApi(Resource):
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)

View File

@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
@ -16,6 +16,19 @@ def _validate_name(name):
return name
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@ -30,6 +43,7 @@ class TagListApi(Resource):
return tags, 200
@api.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required
@ -39,20 +53,7 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -60,8 +61,14 @@ class TagListApi(Resource):
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@api.expect(parser_tag_id)
@setup_required
@login_required
@account_initialization_required
@ -72,10 +79,7 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
args = parser.parse_args()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -99,8 +103,17 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@ -110,26 +123,23 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@api.expect(parser_remove)
@setup_required
@login_required
@account_initialization_required
@ -139,15 +149,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
return {"result": "success"}, 200

View File

@ -11,16 +11,16 @@ from . import api, console_ns
logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
@console_ns.route("/version")
class VersionApi(Resource):
@api.doc("check_version_update")
@api.doc(description="Check for application version updates")
@api.expect(
api.parser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
)
@api.expect(parser)
@api.response(
200,
"Success",
@ -37,7 +37,6 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL

View File

@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@ -43,8 +43,19 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@api.expect(_init_parser())
@setup_required
@login_required
def post(self):
@ -53,14 +64,7 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
args = parser.parse_args()
args = _init_parser().parse_args()
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
@ -106,16 +110,19 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@api.expect(parser_name)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
@ -126,68 +133,80 @@ class AccountNameApi(Resource):
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@api.expect(parser_avatar)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_avatar.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@api.expect(parser_interface)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
args = parser.parse_args()
args = parser_interface.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@api.expect(parser_theme)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
args = parser.parse_args()
args = parser_theme.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@api.expect(parser_timezone)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_timezone.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
@ -198,21 +217,24 @@ class AccountTimezoneApi(Resource):
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@api.expect(parser_pw)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
@ -294,20 +316,23 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_delete.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
raise InvalidAccountDeletionCodeError()
@ -317,16 +342,19 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@api.expect(parser_feedback)
@setup_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_feedback.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@ -351,6 +379,14 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@ -360,6 +396,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@api.expect(parser_edu)
@setup_required
@login_required
@account_initialization_required
@ -368,13 +405,7 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_edu.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@ -394,6 +425,14 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@ -402,6 +441,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@api.expect(parser_autocomplete)
@setup_required
@login_required
@account_initialization_required
@ -409,33 +449,30 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args()
args = parser_autocomplete.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@api.expect(parser_change_email)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_change_email.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -470,20 +507,23 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@api.expect(parser_validity)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validity.parse_args()
user_email = args["email"]
@ -514,20 +554,23 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@api.expect(parser_reset)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_reset.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()
@ -555,12 +598,15 @@ class ChangeEmailResetApi(Resource):
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@api.expect(parser_check)
@setup_required
def post(self):
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@ -48,22 +48,25 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@api.expect(parser_invite)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_invite.parse_args()
invitee_emails = args["emails"]
invitee_role = args["role"]
@ -143,16 +146,19 @@ class MemberCancelInviteApi(Resource):
}, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@api.expect(parser_update)
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_update.parse_args()
new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role):
@ -191,17 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@api.expect(parser_send)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = parser_send.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -229,19 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@api.expect(parser_owner)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_owner.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@ -276,17 +288,20 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@api.expect(parser_owner_transfer)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
parser = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_owner_transfer.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()

View File

@ -4,7 +4,7 @@ from flask import send_file
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -14,9 +14,19 @@ from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@api.expect(parser_model)
@setup_required
@login_required
@account_initialization_required
@ -24,15 +34,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_model.parse_args()
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
@ -40,8 +42,30 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@api.expect(parser_cred)
@setup_required
@login_required
@account_initialization_required
@ -49,10 +73,7 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args()
args = parser_cred.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
@ -61,6 +82,7 @@ class ModelProviderCredentialApi(Resource):
return {"credentials": credentials}
@api.expect(parser_post_cred)
@setup_required
@login_required
@account_initialization_required
@ -69,12 +91,7 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@ -90,6 +107,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@account_initialization_required
@ -98,13 +116,7 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@ -121,6 +133,7 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@account_initialization_required
@ -128,10 +141,8 @@ class ModelProviderCredentialApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
@ -141,8 +152,14 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
@ -150,10 +167,7 @@ class ModelProviderCredentialSwitchApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_switch.parse_args()
service = ModelProviderService()
service.switch_active_provider_credential(
@ -164,17 +178,20 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_validate.parse_args()
tenant_id = current_tenant_id
@ -218,8 +235,19 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@api.expect(parser_preferred)
@setup_required
@login_required
@account_initialization_required
@ -230,15 +258,7 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
args = parser.parse_args()
args = parser_preferred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@ -16,23 +16,29 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
parser_get_default = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@api.expect(parser_get_default)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_get_default.parse_args()
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
@ -41,6 +47,7 @@ class DefaultModelApi(Resource):
return jsonable_encoder({"data": default_model_entity})
@api.expect(parser_post_default)
@setup_required
@login_required
@account_initialization_required
@ -50,10 +57,7 @@ class DefaultModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_post_default.parse_args()
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@ -84,6 +88,35 @@ class DefaultModelApi(Resource):
return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@ -97,6 +130,7 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
@api.expect(parser_post_models)
@setup_required
@login_required
@account_initialization_required
@ -106,23 +140,7 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_post_models.parse_args()
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
@ -160,6 +178,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@api.expect(parser_delete_models)
@setup_required
@login_required
@account_initialization_required
@ -169,19 +188,7 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_delete_models.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
@ -191,29 +198,76 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_get_credentials)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_get_credentials.parse_args()
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
@ -257,6 +311,7 @@ class ModelProviderModelCredentialApi(Resource):
}
)
@api.expect(parser_post_cred)
@setup_required
@login_required
@account_initialization_required
@ -266,21 +321,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@ -304,6 +345,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@account_initialization_required
@ -313,22 +355,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@ -347,6 +374,7 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@account_initialization_required
@ -355,20 +383,7 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
@ -382,8 +397,24 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
@ -392,20 +423,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_switch.parse_args()
service = ModelProviderService()
service.add_model_credential_to_model_list(
@ -418,29 +436,32 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
@ -454,25 +475,14 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
@ -482,28 +492,31 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
parser_validate = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validate.parse_args()
model_provider_service = ModelProviderService()
@ -530,16 +543,19 @@ class ModelProviderModelValidateApi(Resource):
return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@api.expect(parser_parameter)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_parameter.parse_args()
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@ -37,19 +37,22 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
parser_list = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@api.expect(parser_list)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args()
args = parser_list.parse_args()
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
except PluginDaemonClientSideError as e:
@ -58,14 +61,17 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@api.expect(parser_latest)
@setup_required
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
args = parser_latest.parse_args()
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
@ -75,16 +81,19 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@api.expect(parser_ids)
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
args = parser_ids.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
@ -94,16 +103,19 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@api.expect(parser_icon)
@setup_required
def get(self):
req = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args()
args = parser_icon.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
@ -114,6 +126,25 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
req.add_argument("file_name", type=str, required=True, location="args")
args = req.parse_args()
_, tenant_id = current_account_with_tenant()
try:
binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e:
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource):
@setup_required
@ -138,8 +169,17 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@api.expect(parser_github)
@setup_required
@login_required
@account_initialization_required
@ -147,13 +187,7 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
@ -187,19 +221,21 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@api.expect(parser_pkg)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_pkg.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@ -214,8 +250,18 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@api.expect(parser_githubapi)
@setup_required
@login_required
@account_initialization_required
@ -223,14 +269,7 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_githubapi.parse_args()
try:
response = PluginService.install_from_github(
@ -246,8 +285,14 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@api.expect(parser_marketplace)
@setup_required
@login_required
@account_initialization_required
@ -255,10 +300,7 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_marketplace.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@ -273,19 +315,21 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@api.expect(parser_pkgapi)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_pkgapi.parse_args()
try:
return jsonable_encoder(
@ -300,8 +344,14 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@api.expect(parser_fetch)
@setup_required
@login_required
@account_initialization_required
@ -309,10 +359,7 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_fetch.parse_args()
try:
return jsonable_encoder(
@ -326,8 +373,16 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@api.expect(parser_tasks)
@setup_required
@login_required
@account_initialization_required
@ -335,12 +390,7 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
args = parser_tasks.parse_args()
try:
return jsonable_encoder(
@ -410,8 +460,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@api.expect(parser_marketplace_api)
@setup_required
@login_required
@account_initialization_required
@ -419,12 +477,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_marketplace_api.parse_args()
try:
return jsonable_encoder(
@ -436,8 +489,19 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@api.expect(parser_github_post)
@setup_required
@login_required
@account_initialization_required
@ -445,15 +509,7 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github_post.parse_args()
try:
return jsonable_encoder(
@ -470,15 +526,20 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@api.expect(parser_uninstall)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_uninstall.parse_args()
_, tenant_id = current_account_with_tenant()
@ -488,8 +549,16 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@api.expect(parser_change_post)
@setup_required
@login_required
@account_initialization_required
@ -499,12 +568,7 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args()
args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
@ -539,8 +603,20 @@ class PluginFetchPermissionApi(Resource):
)
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@api.expect(parser_dynamic)
@setup_required
@login_required
@account_initialization_required
@ -552,25 +628,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
user_id = current_user.id
parser = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args()
args = parser_dynamic.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id,
user_id,
args["plugin_id"],
args["provider"],
args["action"],
args["parameter"],
args["provider_type"],
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args["plugin_id"],
provider=args["provider"],
action=args["action"],
parameter=args["parameter"],
credential_id=args["credential_id"],
provider_type=args["provider_type"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
@ -578,8 +647,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@api.expect(parser_change)
@setup_required
@login_required
@account_initialization_required
@ -588,12 +665,7 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args()
args = parser_change.parse_args()
permission = args["permission"]
@ -673,8 +745,12 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@api.expect(parser_exclude)
@setup_required
@login_required
@account_initialization_required
@ -682,7 +758,26 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_exclude.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=False, location="args")
args = parser.parse_args()
return jsonable_encoder(
{
"readme": PluginService.fetch_plugin_readme(
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
)
}
)

View File

@ -6,29 +6,33 @@ from flask_restx import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
# from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@ -42,12 +46,25 @@ def is_valid_url(url: str) -> bool:
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
except (ValueError, TypeError):
# ValueError: Invalid URL format
# TypeError: url is not a string
return False
parser_tool = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@api.expect(parser_tool)
@setup_required
@login_required
@account_initialization_required
@ -56,15 +73,7 @@ class ToolProviderListApi(Resource):
user_id = user.id
req = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
args = req.parse_args()
args = parser_tool.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@ -96,8 +105,14 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
parser_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@account_initialization_required
@ -106,10 +121,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = req.parse_args()
args = parser_delete.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
tenant_id,
@ -118,8 +130,17 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
parser_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@api.expect(parser_add)
@setup_required
@login_required
@account_initialization_required
@ -128,13 +149,7 @@ class ToolBuiltinProviderAddApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_add.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
@ -149,8 +164,17 @@ class ToolBuiltinProviderAddApi(Resource):
)
parser_update = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@api.expect(parser_update)
@setup_required
@login_required
@account_initialization_required
@ -162,14 +186,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_update.parse_args()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
@ -207,8 +224,22 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
parser_api_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@api.expect(parser_api_add)
@setup_required
@login_required
@account_initialization_required
@ -220,19 +251,7 @@ class ToolApiProviderAddApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_add.parse_args()
return ApiToolManageService.create_api_tool_provider(
user_id,
@ -248,8 +267,12 @@ class ToolApiProviderAddApi(Resource):
)
parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@api.expect(parser_remote)
@setup_required
@login_required
@account_initialization_required
@ -258,9 +281,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
args = parser_remote.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
user_id,
@ -269,8 +290,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
parser_tools = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@api.expect(parser_tools)
@setup_required
@login_required
@account_initialization_required
@ -279,11 +306,7 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_tools.parse_args()
return jsonable_encoder(
ApiToolManageService.list_api_tool_provider_tools(
@ -294,8 +317,23 @@ class ToolApiProviderListToolsApi(Resource):
)
parser_api_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@api.expect(parser_api_update)
@setup_required
@login_required
@account_initialization_required
@ -307,20 +345,7 @@ class ToolApiProviderUpdateApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_update.parse_args()
return ApiToolManageService.update_api_tool_provider(
user_id,
@ -337,8 +362,14 @@ class ToolApiProviderUpdateApi(Resource):
)
parser_api_delete = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@api.expect(parser_api_delete)
@setup_required
@login_required
@account_initialization_required
@ -350,11 +381,7 @@ class ToolApiProviderDeleteApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_api_delete.parse_args()
return ApiToolManageService.delete_api_tool_provider(
user_id,
@ -363,8 +390,12 @@ class ToolApiProviderDeleteApi(Resource):
)
parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@api.expect(parser_get)
@setup_required
@login_required
@account_initialization_required
@ -373,11 +404,7 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_get.parse_args()
return ApiToolManageService.get_api_tool_provider(
user_id,
@ -401,40 +428,44 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
parser_schema = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@api.expect(parser_schema)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_schema.parse_args()
return ApiToolManageService.parser_api_schema(
schema=args["schema"],
)
parser_pre = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@api.expect(parser_pre)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_pre.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_tenant_id,
@ -447,8 +478,22 @@ class ToolApiProviderPreviousTestApi(Resource):
)
parser_create = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@ -460,19 +505,7 @@ class ToolWorkflowProviderCreateApi(Resource):
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_create.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id=user_id,
@ -488,8 +521,22 @@ class ToolWorkflowProviderCreateApi(Resource):
)
parser_workflow_update = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@api.expect(parser_workflow_update)
@setup_required
@login_required
@account_initialization_required
@ -501,19 +548,7 @@ class ToolWorkflowProviderUpdateApi(Resource):
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_workflow_update.parse_args()
if not args["workflow_tool_id"]:
raise ValueError("incorrect workflow_tool_id")
@ -532,8 +567,14 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
parser_workflow_delete = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@api.expect(parser_workflow_delete)
@setup_required
@login_required
@account_initialization_required
@ -545,11 +586,7 @@ class ToolWorkflowProviderDeleteApi(Resource):
user_id = user.id
reqparser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args()
args = parser_workflow_delete.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
@ -558,8 +595,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
parser_wf_get = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@api.expect(parser_wf_get)
@setup_required
@login_required
@account_initialization_required
@ -568,13 +613,7 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_wf_get.parse_args()
if args.get("workflow_tool_id"):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
@ -594,8 +633,14 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
parser_wf_tools = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@api.expect(parser_wf_tools)
@setup_required
@login_required
@account_initialization_required
@ -604,11 +649,7 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_wf_tools.parse_args()
return jsonable_encoder(
WorkflowToolManageService.list_single_workflow_tools(
@ -784,32 +825,40 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_default_cred = reqparse.RequestParser().add_argument(
"id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@api.expect(parser_default_cred)
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default_cred.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
parser_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@api.expect(parser_custom)
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_custom.parse_args()
user, tenant_id = current_account_with_tenant()
@ -872,141 +921,185 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
parser_mcp = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_put = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_delete = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@api.expect(parser_mcp)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
args = parser_mcp.parse_args()
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
# Parse and validate models
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
# Create provider
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
tenant_id=tenant_id,
user_id=user.id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
)
)
return jsonable_encoder(result)
@api.expect(parser_mcp_put)
@setup_required
@login_required
@account_initialization_required
def put(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
args = parser_mcp_put.parse_args()
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
headers=args.get("headers"),
)
return {"result": "success"}
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
headers=args["headers"],
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
@api.expect(parser_mcp_delete)
@setup_required
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_mcp_delete.parse_args()
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
parser_auth = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@api.expect(parser_auth)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_auth.parse_args()
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
try:
with MCPClient(
provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
headers=provider.decrypted_headers,
timeout=provider.timeout,
sse_read_timeout=provider.sse_read_timeout,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
raise ValueError("provider not found")
# Convert to entity
provider_entity = db_provider.to_entity()
server_url = provider_entity.decrypt_server_url()
headers = provider_entity.decrypt_authentication()
# Try to connect without active transaction
try:
# Use MCPClientWithAuthRetry to handle authentication automatically
with MCPClient(
server_url=server_url,
headers=headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
tenant_id=tenant_id,
credentials=provider_entity.credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials={},
authed=False,
)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@ -1017,8 +1110,10 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@console_ns.route("/workspaces/current/tools/mcp")
@ -1029,9 +1124,12 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
return [tool.to_dict() for tool in tools]
return [tool.to_dict() for tool in tools]
@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
@ -1041,23 +1139,38 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
parser_cb = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
@api.expect(parser_cb)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args()
args = parser_cb.parse_args()
state_key = args["state"]
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
# Save tokens using the service layer
mcp_service.save_oauth_data(
state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@ -0,0 +1,592 @@
import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
logger = logging.getLogger(__name__)
class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
)
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except ValueError as e:
return jsonable_encoder({"error": str(e)}), 404
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
)
return jsonable_encoder({"subscription_builder": subscription_builder})
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
# Use atomic update_and_verify to prevent race conditions
return TriggerSubscriptionBuilderService.update_and_verify_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
),
)
except Exception as e:
logger.exception("Error verifying provider credential", exc_info=e)
raise ValueError(str(e)) from e
class TriggerSubscriptionBuilderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
),
)
)
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
except Exception as e:
logger.exception("Error getting request logs for subscription builder", exc_info=e)
raise
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
),
)
return 200
except Exception as e:
logger.exception("Error building provider credential", exc_info=e)
raise ValueError(str(e)) from e
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, subscription_id: str):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
# Delete plugin triggers
TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription(
session=session,
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error deleting provider credential", exc_info=e)
raise
class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise NotFoundError("No OAuth client configuration found for this trigger provider")
# Create subscription builder
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=tenant_id,
user_id=user.id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
)
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
extra_data={
"subscription_builder_id": subscription_builder.id,
},
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(
jsonable_encoder(
{
"authorization_url": authorization_url_response.authorization_url,
"subscription_builder_id": subscription_builder.id,
"subscription_builder": subscription_builder,
}
)
)
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise ValueError("Failed to get OAuth credentials from the provider.")
# Update subscription builder
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=credentials,
credential_expires_at=expires_at,
),
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
system_client_exists = TriggerProviderService.is_oauth_system_client_exists(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
"configured": bool(custom_params or system_client_exists),
"system_configured": system_client_exists,
"custom_configured": bool(custom_params),
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
"custom_enabled": is_custom_enabled,
"redirect_uri": redirect_uri,
"params": custom_params or {},
}
)
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger Subscription
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
api.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
api.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@ -21,6 +21,7 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
setup_required,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
@ -83,7 +84,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
@ -149,15 +150,18 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_switch.parse_args()
# check if tenant_id is valid, 403 if not
try:
@ -241,16 +245,19 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@api.expect(parser_info)
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_info.parse_args()
if not current_tenant_id:
raise ValueError("No current tenant")

View File

@ -10,6 +10,7 @@ from flask import abort, request
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
abort(
403,
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",

View File

@ -14,10 +14,25 @@ from services.file_service import FileService
@files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource):
"""
Deprecated
"""
"""Deprecated endpoint for retrieving image previews."""
@files_ns.doc("get_image_preview")
@files_ns.doc(description="Retrieve a signed image preview for a file")
@files_ns.doc(
params={
"file_id": "ID of the file to preview",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
}
)
@files_ns.doc(
responses={
200: "Image preview returned successfully",
400: "Missing or invalid signature parameters",
415: "Unsupported file type",
}
)
def get(self, file_id):
file_id = str(file_id)
@ -43,6 +58,25 @@ class ImagePreviewApi(Resource):
@files_ns.route("/<uuid:file_id>/file-preview")
class FilePreviewApi(Resource):
@files_ns.doc("get_file_preview")
@files_ns.doc(description="Download a file preview or attachment using signed parameters")
@files_ns.doc(
params={
"file_id": "ID of the file to preview",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
"as_attachment": "Whether to download the file as an attachment",
}
)
@files_ns.doc(
responses={
200: "File stream returned successfully",
400: "Missing or invalid signature parameters",
404: "File not found",
415: "Unsupported file type",
}
)
def get(self, file_id):
file_id = str(file_id)
@ -101,6 +135,20 @@ class FilePreviewApi(Resource):
@files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
class WorkspaceWebappLogoApi(Resource):
@files_ns.doc("get_workspace_webapp_logo")
@files_ns.doc(description="Fetch the custom webapp logo for a workspace")
@files_ns.doc(
params={
"workspace_id": "Workspace identifier",
}
)
@files_ns.doc(
responses={
200: "Logo returned successfully",
404: "Webapp logo not configured",
415: "Unsupported file type",
}
)
def get(self, workspace_id):
workspace_id = str(workspace_id)

View File

@ -13,6 +13,26 @@ from extensions.ext_database import db as global_db
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource):
@files_ns.doc("get_tool_file")
@files_ns.doc(description="Download a tool file by ID using signed parameters")
@files_ns.doc(
params={
"file_id": "Tool file identifier",
"extension": "Expected file extension",
"timestamp": "Unix timestamp used in the signature",
"nonce": "Random string used in the signature",
"sign": "HMAC signature verifying the request",
"as_attachment": "Whether to download the file as an attachment",
}
)
@files_ns.doc(
responses={
200: "Tool file stream returned successfully",
403: "Forbidden - invalid signature",
404: "File not found",
415: "Unsupported file type",
}
)
def get(self, file_id, extension):
file_id = str(file_id)

View File

@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"created_at": int(document.created_at.timestamp()),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

View File

@ -2,6 +2,7 @@ from flask import request
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)

View File

@ -13,13 +13,15 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
from models.model import ApiToken, App
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
P = ParamSpec("P")
@ -66,6 +68,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
kwargs["app_model"] = app_model
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
@ -74,7 +77,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
@ -83,12 +85,34 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
if user_id:
user_id = str(user_id)
end_user = create_or_update_end_user_for_user_id(app_model, user_id)
end_user = EndUserService.get_or_create_end_user(app_model, user_id)
kwargs["end_user"] = end_user
# Set EndUser as current logged-in user for flask_login.current_user
current_app.login_manager._update_request_context_with_user(end_user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore
else:
# For service API without end-user context, ensure an Account is logged in
# so services relying on current_account_with_tenant() work correctly.
tenant_owner_info = (
db.session.query(Tenant, Account)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.join(Account, TenantAccountJoin.account_id == Account.id)
.where(
Tenant.id == app_model.tenant_id,
TenantAccountJoin.role == "owner",
Tenant.status == TenantStatus.NORMAL,
)
.one_or_none()
)
if tenant_owner_info:
tenant_model, account = tenant_owner_info
account.current_tenant = tenant_model
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account not found or tenant is not active.")
return view_func(*args, **kwargs)
@ -138,7 +162,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
raise Forbidden(
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
)
@ -308,39 +332,6 @@ def validate_and_get_api_token(scope: str | None = None):
return api_token
def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
session.query(EndUser)
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == "service_api",
)
.first()
)
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
)
session.add(end_user)
session.commit()
return end_user
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]

View File

@ -0,0 +1,12 @@
from flask import Blueprint
# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook
__all__ = [
"trigger",
"webhook",
]

View File

@ -0,0 +1,43 @@
import logging
import re
from flask import jsonify, request
from werkzeug.exceptions import NotFound
from controllers.trigger import bp
from services.trigger.trigger_service import TriggerService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
logger = logging.getLogger(__name__)
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
"""
Handle endpoint trigger calls.
"""
# endpoint_id must be UUID
if not UUID_MATCHER.match(endpoint_id):
raise NotFound("Invalid endpoint ID")
handling_chain = [
TriggerService.process_endpoint,
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
]
response = None
try:
for handler in handling_chain:
response = handler(endpoint_id, request)
if response:
break
if not response:
logger.error("Endpoint not found for {endpoint_id}")
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e:
return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
except Exception:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error"}), 500

View File

@ -0,0 +1,105 @@
import logging
import time
from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService
logger = logging.getLogger(__name__)
def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
"""Fetch trigger context, extract request data, and validate payload using unified processing.
Args:
webhook_id: The webhook ID to process
is_debug: If True, skip status validation for debug mode
"""
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(
webhook_id, is_debug=is_debug
)
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
return webhook_trigger, workflow, node_config, webhook_data, None
except ValueError as e:
# Fall back to raw extraction for error reporting
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
return webhook_trigger, workflow, node_config, webhook_data, str(e)
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
"""
Handle webhook trigger calls.
This endpoint receives webhook calls and processes them according to the
configured webhook trigger settings.
"""
try:
webhook_trigger, workflow, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id)
if error:
return jsonify({"error": "Bad Request", "message": error}), 400
# Process webhook call (send to Celery)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
# Return configured response
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code
except ValueError as e:
raise NotFound(str(e))
except RequestEntityTooLarge:
raise
except Exception as e:
logger.exception("Webhook processing failed for %s", webhook_id)
return jsonify({"error": "Internal server error", "message": str(e)}), 500
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
"""Handle webhook debug calls without triggering production workflow execution."""
try:
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
if error:
return jsonify({"error": "Bad Request", "message": error}), 400
workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)
# Generate pool key and dispatch debug event
pool_key: str = build_webhook_pool_key(
tenant_id=webhook_trigger.tenant_id,
app_id=webhook_trigger.app_id,
node_id=webhook_trigger.node_id,
)
event = WebhookDebugEvent(
request_id=f"webhook_debug_{webhook_trigger.webhook_id}_{int(time.time() * 1000)}",
timestamp=int(time.time()),
node_id=webhook_trigger.node_id,
payload={
"inputs": workflow_inputs,
"webhook_data": webhook_data,
"method": webhook_data.get("method"),
},
)
TriggerDebugEventBus.dispatch(
tenant_id=webhook_trigger.tenant_id,
event=event,
pool_key=pool_key,
)
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code
except ValueError as e:
raise NotFound(str(e))
except RequestEntityTooLarge:
raise
except Exception as e:
logger.exception("Webhook debug processing failed for %s", webhook_id)
return jsonify({"error": "Internal server error", "message": "An internal error has occurred."}), 500

View File

@ -88,12 +88,6 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
text_to_audio_response_fields = {
"audio_url": fields.String,
"duration": fields.Float,
}
@marshal_with(text_to_audio_response_fields)
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(

View File

@ -17,8 +17,8 @@ from libs.helper import email
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
clear_access_token_from_cookie,
extract_access_token,
clear_webapp_access_token_from_cookie,
extract_webapp_access_token,
)
from services.account_service import AccountService
from services.app_service import AppService
@ -81,7 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
token = extract_access_token(request)
token = extract_webapp_access_token(request)
if not app_code:
return {
"logged_in": bool(token),
@ -128,7 +128,7 @@ class LogoutApi(Resource):
response = make_response({"result": "success"})
# enterprise SSO sets same site to None in https deployment
# so we need to logout by calling api
clear_access_token_from_cookie(response, samesite="None")
clear_webapp_access_token_from_cookie(response, samesite="None")
return response

View File

@ -12,7 +12,7 @@ from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from libs.token import extract_webapp_access_token
from models.model import App, EndUser, Site
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -35,7 +35,7 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features()
app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id")
access_token = extract_access_token(request)
access_token = extract_webapp_access_token(request)
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
if system_features.webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import select
@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@ -1,3 +1,4 @@
import json
import logging
import re
import time
@ -60,6 +61,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
@ -391,6 +393,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if should_direct_answer:
return
current_time = time.perf_counter()
if self._task_state.first_token_time is None and delta_text.strip():
self._task_state.first_token_time = current_time
self._task_state.is_streaming_response = True
if delta_text.strip():
self._task_state.last_token_time = current_time
# Only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
@ -772,7 +782,33 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
message.answer = answer_text
message.updated_at = naive_utc_now()
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
message.message_metadata = self._task_state.metadata.model_dump_json()
# Set usage first before dumping metadata
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata.usage = usage
else:
usage = LLMUsage.empty_usage()
self._task_state.metadata.usage = usage
# Add streaming metrics to usage if available
if self._task_state.is_streaming_response and self._task_state.first_token_time:
start_time = self._base_task_pipeline.start_at
first_token_time = self._task_state.first_token_time
last_token_time = self._task_state.last_token_time or first_token_time
usage.time_to_first_token = round(first_token_time - start_time, 3)
usage.time_to_generate = round(last_token_time - first_token_time, 3)
metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata))
message_files = [
MessageFile(
message_id=message.id,
@ -790,20 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
message.message_tokens = usage.prompt_tokens
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer_tokens = usage.completion_tokens
message.answer_unit_price = usage.completion_unit_price
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)
@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query or "",
query=query,
memory=memory,
)

View File

@ -79,7 +79,7 @@ class AppRunner:
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: str | None = None,
query: str = "",
context: str | None = None,
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@ -105,7 +105,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query or "",
query=query,
files=files,
context=context,
memory=memory,

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
@ -37,6 +37,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import (
NodeType,
@ -51,7 +52,7 @@ from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from services.variable_truncator import VariableTruncator
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@ -70,6 +71,8 @@ class _NodeSnapshot:
class WorkflowResponseConverter:
_truncator: BaseTruncator
def __init__(
self,
*,
@ -81,7 +84,13 @@ class WorkflowResponseConverter:
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
self._truncator = VariableTruncator.default()
# Disable truncation for SERVICE_API calls to keep backward compatibility.
if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
self._truncator = DummyVariableTruncator()
else:
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None
@ -295,6 +304,11 @@ class WorkflowResponseConverter:
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
)
return response

View File

@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
conversation_id=conversation.id,
inputs=application_generate_entity.inputs,
query=application_generate_entity.query or "",
query=application_generate_entity.query,
message="",
message_tokens=0,
message_unit_price=0,

View File

@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
)
if rag_pipeline_invoke_entities:
# store the rag_pipeline_invoke_entities to object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# return batch, dataset, documents
return {
"batch": batch,

View File

@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -38,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@staticmethod
def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
@overload
def generate(
self,
@ -53,7 +60,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
) -> Generator[Mapping | str, None, None]: ...
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
def generate(
@ -66,6 +76,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any]: ...
@overload
@ -79,7 +92,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
@ -91,7 +107,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
@ -126,17 +145,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# for trigger debug run, not prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
inputs=inputs,
files=list(system_files),
user_id=user.id,
stream=streaming,
@ -155,7 +177,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
if invoke_from == InvokeFrom.DEBUGGER:
if triggered_from is not None:
# Use explicitly provided triggered_from (for async triggers)
workflow_triggered_from = triggered_from
elif invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@ -182,8 +207,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
def resume(self, *, workflow_run_id: str) -> None:
"""
@TBD
"""
pass
def _generate(
self,
*,
@ -196,6 +229,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -231,8 +266,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
},
)
@ -426,6 +463,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> None:
"""
Generate worker in a new thread.
@ -469,6 +508,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
)
try:

View File

@ -1,5 +1,6 @@
import logging
import time
from collections.abc import Sequence
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -8,6 +9,7 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -16,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
@ -35,17 +38,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
root_node_id: str | None = None,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
graph_engine_layers=graph_engine_layers,
)
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
@ -60,6 +67,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
@ -92,6 +100,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
root_node_id=self._root_node_id,
)
# RUN WORKFLOW
@ -135,6 +144,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
workflow_entry.graph_engine.layer(persistence_layer)
for layer in self._graph_engine_layers:
workflow_entry.graph_engine.layer(layer)
generator = workflow_entry.run()

View File

@ -1,5 +1,5 @@
import time
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
):
self._queue_manager = queue_manager
self._variable_loader = variable_loader
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
def _init_graph(
self,
@ -81,6 +84,7 @@ class WorkflowBasedAppRunner:
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
root_node_id: str | None = None,
) -> Graph:
"""
Init graph
@ -114,7 +118,7 @@ class WorkflowBasedAppRunner:
)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
if not graph:
raise ValueError("graph not found in workflow")

View File

@ -32,6 +32,10 @@ class InvokeFrom(StrEnum):
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"
# TRIGGER indicates that this invocation is from a trigger.
# this is used for plugin trigger and webhook trigger.
TRIGGER = "trigger"
# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
@ -40,6 +44,9 @@ class InvokeFrom(StrEnum):
DEBUGGER = "debugger"
PUBLISHED = "published"
# VALIDATION indicates that this invocation is from validation.
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str):
"""
@ -65,6 +72,8 @@ class InvokeFrom(StrEnum):
return "dev"
elif self == InvokeFrom.EXPLORE:
return "explore_app"
elif self == InvokeFrom.TRIGGER:
return "trigger"
elif self == InvokeFrom.SERVICE_API:
return "api"
@ -104,6 +113,11 @@ class AppGenerateEntity(BaseModel):
inputs: Mapping[str, Any]
files: Sequence[File]
# Unique identifier of the user initiating the execution.
# This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
#
# Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
user_id: str
# extras
@ -129,7 +143,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
app_config: EasyUIBasedAppConfig = None # type: ignore
model_conf: ModelConfigWithCredentialsEntity
query: str | None = None
query: str = ""
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@ -48,6 +48,9 @@ class WorkflowTaskState(TaskState):
"""
answer: str = ""
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class StreamEvent(StrEnum):

View File

@ -0,0 +1,133 @@
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
entity: WorkflowAppGenerateEntity
class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
entity: AdvancedChatAppGenerateEntity
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]
class WorkflowResumptionContext(BaseModel):
"""WorkflowResumptionContext captures all state necessary for resumption."""
version: Literal["1"] = "1"
# Only workflow / chatflow could be paused.
generate_entity: _GenerateEntityUnion
serialized_graph_runtime_state: str
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, value: str) -> Self:
return cls.model_validate_json(value)
def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
return self.generate_entity.entity
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,
session_factory: Engine | sessionmaker[Session],
generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
state_owner_user_id: str,
):
"""Create a PauseStatePersistenceLayer.
The `state_owner_user_id` is used when creating state file for pause.
It generally should id of the creator of workflow.
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
else:
entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)
state = WorkflowResumptionContext(
serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
generate_entity=entity_wrapper,
)
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id is not None
repo = self._get_repo()
repo.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id,
state=state.dumps(),
)
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass

View File

@ -0,0 +1,21 @@
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
class SuspendLayer(GraphEngineLayer):
""" """
def on_graph_start(self):
pass
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
"""
if isinstance(event, GraphRunPausedEvent):
pass
def on_graph_end(self, error: Exception | None):
""" """
pass

View File

@ -0,0 +1,88 @@
import logging
import uuid
from typing import ClassVar
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand
logger = logging.getLogger(__name__)
class TimeSliceLayer(GraphEngineLayer):
"""
CFS plan scheduler to control the timeslice of the workflow.
"""
scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()
def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
"""
CFS plan scheduler allows to control the timeslice of the workflow.
"""
if not TimeSliceLayer.scheduler.running:
TimeSliceLayer.scheduler.start()
super().__init__()
self.cfs_plan_scheduler = cfs_plan_scheduler
self.stopped = False
self.schedule_id = ""
def _checker_job(self, schedule_id: str):
"""
Check if the workflow need to be suspended.
"""
try:
if self.stopped:
self.scheduler.remove_job(schedule_id)
return
if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
# remove the job
self.scheduler.remove_job(schedule_id)
if not self.command_channel:
logger.exception("No command channel to stop the workflow")
return
# send command to pause the workflow
self.command_channel.send_command(
GraphEngineCommand(
command_type=CommandType.PAUSE,
payload={
"reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
},
)
)
except Exception:
logger.exception("scheduler error during check if the workflow need to be suspended")
def on_graph_start(self):
"""
Start timer to check if the workflow need to be suspended.
"""
if self.cfs_plan_scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice:
self.schedule_id = uuid.uuid4().hex
self.scheduler.add_job(
lambda: self._checker_job(self.schedule_id),
"interval",
seconds=self.cfs_plan_scheduler.plan.granularity,
id=self.schedule_id,
)
def on_event(self, event: GraphEngineEvent):
pass
def on_graph_end(self, error: Exception | None) -> None:
self.stopped = True
# remove the scheduler
if self.schedule_id:
self.scheduler.remove_job(self.schedule_id)

View File

@ -0,0 +1,88 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
from models.enums import WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity
logger = logging.getLogger(__name__)
class TriggerPostLayer(GraphEngineLayer):
"""
Trigger post layer.
"""
_STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = {
GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED,
GraphRunFailedEvent: WorkflowTriggerStatus.FAILED,
GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED,
}
def __init__(
self,
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
self.session_maker = session_maker
def on_graph_start(self):
pass
def on_event(self, event: GraphEngineEvent):
"""
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
with self.session_maker() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
logger.exception("Trigger log not found: %s", self.trigger_log_id)
return
# Calculate elapsed time
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id
workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
assert workflow_run_id, "Workflow run id is not set"
total_tokens = self.graph_runtime_state.total_tokens
# Update trigger log with success
trigger_log.status = self._STATUS_MAP[type(event)]
trigger_log.workflow_run_id = workflow_run_id
trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()
if trigger_log.elapsed_time is None:
trigger_log.elapsed_time = elapsed_time
else:
trigger_log.elapsed_time += elapsed_time
trigger_log.total_tokens = total_tokens
trigger_log.finished_at = datetime.now(UTC)
repo.update(trigger_log)
session.commit()
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

View File

@ -140,7 +140,27 @@ class MessageCycleManager:
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
# Add new unique resources from the event
for resource in event.retriever_resources or []:
if not resource:
continue
is_duplicate = (
resource.dataset_id
and resource.document_id
and (resource.dataset_id, resource.document_id) in existing_ids
)
if not is_duplicate:
merged_resources.append(resource)
for i, resource in enumerate(merged_resources, 1):
resource.position = i
self._task_state.metadata.retriever_resources = merged_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
"""

View File

@ -0,0 +1,15 @@
from collections.abc import Sequence
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: Sequence[str]

View File

@ -0,0 +1,329 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from pydantic import BaseModel
from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
if TYPE_CHECKING:
from models.tools import MCPToolProvider
# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6
class MCPSupportGrantType(StrEnum):
"""The supported grant types for MCP"""
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"
REFRESH_TOKEN = "refresh_token"
class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
class MCPConfiguration(BaseModel):
timeout: float = 30
sse_read_timeout: float = 300
class MCPProviderEntity(BaseModel):
"""MCP Provider domain entity for business logic operations"""
# Basic identification
id: str
provider_id: str # server_identifier
name: str
tenant_id: str
user_id: str
# Server connection info
server_url: str # encrypted URL
headers: dict[str, str] # encrypted headers
timeout: float
sse_read_timeout: float
# Authentication related
authed: bool
credentials: dict[str, Any] # encrypted credentials
code_verifier: str | None = None # for OAuth
# Tools and display info
tools: list[dict[str, Any]] # parsed tools list
icon: str | dict[str, str] # parsed icon
# Timestamps
created_at: datetime
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
"""Create entity from database model with decryption"""
return cls(
id=db_provider.id,
provider_id=db_provider.server_identifier,
name=db_provider.name,
tenant_id=db_provider.tenant_id,
user_id=db_provider.user_id,
server_url=db_provider.server_url,
headers=db_provider.headers,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
authed=db_provider.authed,
credentials=db_provider.credentials,
tools=db_provider.tool_dict,
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
)
@property
def redirect_url(self) -> str:
"""OAuth redirect URL"""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
# Get grant type from credentials
credentials = self.decrypt_credentials()
# Try to get grant_type from different locations
grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)
# For nested structure, check if client_information has grant_types
if "client_information" in credentials and isinstance(credentials["client_information"], dict):
client_info = credentials["client_information"]
# If grant_types is specified in client_information, use it to determine grant_type
if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
if "client_credentials" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
elif "authorization_code" in client_info["grant_types"]:
grant_type = MCPSupportGrantType.AUTHORIZATION_CODE
# Configure based on grant type
is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS
grant_types = ["refresh_token"]
grant_types.append("client_credentials" if is_client_credentials else "authorization_code")
response_types = [] if is_client_credentials else ["code"]
redirect_uris = [] if is_client_credentials else [self.redirect_url]
return OAuthClientMetadata(
redirect_uris=redirect_uris,
token_endpoint_auth_method="none",
grant_types=grant_types,
response_types=response_types,
client_name=CLIENT_NAME,
client_uri=CLIENT_URI,
)
@property
def provider_icon(self) -> dict[str, str] | str:
"""Get provider icon, handling both dict and string formats"""
if isinstance(self.icon, dict):
return self.icon
try:
return json.loads(self.icon)
except (json.JSONDecodeError, TypeError):
# If not JSON, assume it's a file path
return file_helpers.get_signed_file_url(self.icon)
def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
"""Convert to API response format
Args:
user_name: User name to display
include_sensitive: If False, skip expensive decryption operations (for list view optimization)
"""
response = {
"id": self.id,
"author": user_name or "Anonymous",
"name": self.name,
"icon": self.provider_icon,
"type": ToolProviderType.MCP.value,
"is_team_authorization": self.authed,
"server_url": self.masked_server_url(),
"server_identifier": self.provider_id,
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
}
# Add configuration
response["configuration"] = {
"timeout": str(self.timeout),
"sse_read_timeout": str(self.sse_read_timeout),
}
# Skip expensive operations when sensitive data is not needed (e.g., list view)
if not include_sensitive:
response["masked_headers"] = {}
response["is_dynamic_registration"] = True
else:
# Add masked headers
response["masked_headers"] = self.masked_headers()
# Add authentication info if available
masked_creds = self.masked_credentials()
if masked_creds:
response["authentication"] = masked_creds
response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
"is_dynamic_registration", True
)
return response
def retrieve_client_information(self) -> OAuthClientInformation | None:
"""OAuth client information if available"""
credentials = self.decrypt_credentials()
if not credentials:
return None
# Check if we have nested client_information structure
if "client_information" not in credentials:
return None
client_info_data = credentials["client_information"]
if isinstance(client_info_data, dict):
if "encrypted_client_secret" in client_info_data:
client_info_data["client_secret"] = encrypter.decrypt_token(
self.tenant_id, client_info_data["encrypted_client_secret"]
)
return OAuthClientInformation.model_validate(client_info_data)
return None
def retrieve_tokens(self) -> OAuthTokens | None:
"""OAuth tokens if available"""
if not self.credentials:
return None
credentials = self.decrypt_credentials()
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
refresh_token=credentials.get("refresh_token", ""),
)
def masked_server_url(self) -> str:
"""Masked server URL for display"""
parsed = urlparse(self.decrypt_server_url())
if parsed.path and parsed.path != "/":
masked = parsed._replace(path="/******")
return masked.geturl()
return parsed.geturl()
def _mask_value(self, value: str) -> str:
"""Mask a sensitive value for display"""
if len(value) > MIN_UNMASK_LENGTH:
return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
else:
return MASK_CHAR * len(value)
def masked_headers(self) -> dict[str, str]:
"""Masked headers for display"""
return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}
def masked_credentials(self) -> dict[str, str]:
"""Masked credentials for display"""
credentials = self.decrypt_credentials()
if not credentials:
return {}
masked = {}
if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
return {}
client_info = credentials["client_information"]
# Mask sensitive fields from nested structure
if client_info.get("client_id"):
masked["client_id"] = self._mask_value(client_info["client_id"])
if client_info.get("encrypted_client_secret"):
masked["client_secret"] = self._mask_value(
encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
)
if client_info.get("client_secret"):
masked["client_secret"] = self._mask_value(client_info["client_secret"])
return masked
def decrypt_server_url(self) -> str:
"""Decrypt server URL"""
return encrypter.decrypt_token(self.tenant_id, self.server_url)
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
from core.tools.utils.encryption import create_provider_encrypter
if not data:
return {}
# Only decrypt fields that are actually encrypted
# For nested structures, client_information is not encrypted as a whole
encrypted_fields = []
for key, value in data.items():
# Skip nested objects - they are not encrypted
if isinstance(value, dict):
continue
# Only process string values that might be encrypted
if isinstance(value, str) and value:
encrypted_fields.append(key)
if not encrypted_fields:
return data
# Create dynamic config only for encrypted fields
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
# Decrypt only the encrypted fields
decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})
# Merge decrypted data with original data (preserving non-encrypted fields)
result = data.copy()
result.update(decrypted_data)
return result
def decrypt_headers(self) -> dict[str, Any]:
"""Decrypt headers"""
return self._decrypt_dict(self.headers)
def decrypt_credentials(self) -> dict[str, Any]:
"""Decrypt credentials"""
return self._decrypt_dict(self.credentials)
def decrypt_authentication(self) -> dict[str, Any]:
"""Decrypt authentication"""
# Option 1: if headers is provided, use it and don't need to get token
headers = self.decrypt_headers()
# Option 2: Add OAuth token if authed and no headers provided
if not self.headers and self.authed:
token = self.retrieve_tokens()
if token:
headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
return headers

View File

@ -14,6 +14,7 @@ class CommonParameterType(StrEnum):
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
CHECKBOX = "checkbox"
ANY = auto()
# Dynamic select parameter

View File

@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
# Return composite sort key: (model_type value, model position index)
return (model.model_type.value, position_index)
# Deduplicate
provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
# Sort using the composite sort key
return sorted(provider_models, key=get_sort_key)

View File

@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None = None
credentials: dict | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []
@ -207,6 +207,7 @@ class ProviderConfig(BasicProviderConfig):
required: bool = False
default: Union[int, str, float, bool] | None = None
options: list[Option] | None = None
multiple: bool | None = False
label: I18nObject | None = None
help: I18nObject | None = None
url: str | None = None

View File

@ -74,6 +74,10 @@ class File(BaseModel):
storage_key: str | None = None,
dify_model_identity: str | None = FILE_MODEL_IDENTITY,
url: str | None = None,
# Legacy compatibility fields - explicitly handle known extra fields
tool_file_id: str | None = None,
upload_file_id: str | None = None,
datasource_file_id: str | None = None,
):
super().__init__(
id=id,

View File

@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
"""
)
""")
return runner_script

View File

@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
runner_script = dedent(f""" {cls._code_placeholder}
import json
from base64 import b64decode

Some files were not shown because too many files have changed in this diff Show More