merge main

This commit is contained in:
Charles Yao 2025-11-21 09:12:46 -06:00
commit 74fc026fe0
571 changed files with 32706 additions and 8378 deletions

View File

@ -6,11 +6,10 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
source /home/vscode/.bashrc

View File

@ -29,7 +29,7 @@ trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation
# Set default charset
[*.{js,tsx}]
[*.{js,jsx,ts,tsx,mjs}]
indent_style = space
indent_size = 2

View File

@ -62,7 +62,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
sandbox
ssrf_proxy

View File

@ -28,6 +28,11 @@ jobs:
# Format code
uv run ruff format ..
- name: count migration progress
run: |
cd api
./cnt_base.sh
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all

View File

@ -8,7 +8,7 @@ concurrency:
cancel-in-progress: true
jobs:
db-migration-test:
db-migration-test-postgres:
runs-on: ubuntu-latest
steps:
@ -45,7 +45,7 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
db_postgres
redis
- name: Prepare configs
@ -57,3 +57,60 @@ jobs:
env:
DEBUG: true
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_mysql
redis
- name: Prepare configs for MySQL
run: |
cd api
cp .env.example .env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
- name: Run DB Migration
env:
DEBUG: true
run: uv run --directory api flask upgrade-db

View File

@ -51,13 +51,13 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Store (TiDB)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: docker/tidb/docker-compose.yaml
services: |
tidb
tiflash
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
# compose-file: docker/tidb/docker-compose.yaml
# services: |
# tidb
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2
@ -83,8 +83,8 @@ jobs:
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

2
.gitignore vendored
View File

@ -186,6 +186,8 @@ docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"--loglevel",
"INFO"
],

View File

@ -70,6 +70,11 @@ type-check:
@uv run --directory api --dev basedpyright
@echo "✅ Type check complete"
test:
@echo "🧪 Running backend unit tests..."
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
@echo "✅ Tests complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -119,6 +124,7 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@ -128,4 +134,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test

View File

@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# PostgreSQL database configuration
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30
@ -159,12 +162,11 @@ SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). Leading dots are optional.
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -175,6 +177,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
OCEANBASE_FULLTEXT_PARSER=ik
SEEKDB_MEMORY_LIMIT=2G
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
@ -340,15 +353,6 @@ LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306

View File

@ -15,8 +15,8 @@
```bash
cd ../docker
cp middleware.env.example middleware.env
# change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
cd ../api
```
@ -26,6 +26,10 @@
cp .env.example .env
```
> [!IMPORTANT]
>
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
1. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
@ -80,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp:
"""
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
# add before request hook
@dify_app.before_request

7
api/cnt_base.sh Executable file
View File

@ -0,0 +1,7 @@
#!/bin/bash
set -euxo pipefail
for pattern in "Base" "TypeBase"; do
printf "%s " "$pattern"
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
done

View File

@ -77,10 +77,6 @@ class AppExecutionConfig(BaseSettings):
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
description="Maximum number of requests per app per day",
default=5000,
)
class CodeExecutionSandboxConfig(BaseSettings):
@ -1086,7 +1082,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
)
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds",
default=180,
default=60 * 60,
)
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds",

View File

@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
default="postgresql",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
@computed_field # type: ignore[prop-decorator]
@property
@ -204,15 +210,15 @@ class DatabaseConfig(BaseSettings):
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
connect_args = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,

File diff suppressed because one or more lines are too long

View File

@ -104,14 +104,11 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None
resource_id_field: str | None = None
def delete(self, resource_id, api_key_id):
def delete(self, resource_id: str, api_key_id: str):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()

View File

@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(
api.parser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@api.expect(parser)
@api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@ -25,13 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args")
.add_argument("model_mode", type=str, required=True, location="args")
.add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -8,17 +8,19 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
)
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource):
@api.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
)
@api.expect(parser)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
@api.response(400, "Invalid request parameters")
@setup_required
@ -27,12 +29,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

View File

@ -251,6 +251,13 @@ class AnnotationExportApi(Resource):
return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation")
@ -259,6 +266,7 @@ class AnnotationUpdateDeleteApi(Resource):
@api.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions")
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -268,11 +276,6 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation

View File

@ -3,7 +3,7 @@ import uuid
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from werkzeug.exceptions import BadRequest, abort
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
@ -12,14 +12,16 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -106,6 +108,35 @@ class AppListApi(Resource):
if str(app.id) in res:
app.access_mode = res[str(app.id)].access_mode
workflow_capable_app_ids = [
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
draft_trigger_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
)
)
.scalars()
.all()
)
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_fields), 200
@api.doc("create_app")
@ -220,10 +251,8 @@ class AppApi(Resource):
args = parser.parse_args()
app_service = AppService()
# Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType
args_dict: AppServiceType.ArgsDict = {
args_dict: AppService.ArgsDict = {
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
@ -353,12 +382,15 @@ class AppExportApi(Resource):
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@api.doc("check_app_name")
@api.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"})
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
@api.expect(parser)
@api.response(200, "Name availability checked")
@setup_required
@login_required
@ -367,7 +399,6 @@ class AppNameApi(Resource):
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
@ -455,15 +486,11 @@ class AppApiStatus(Resource):
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()

View File

@ -1,6 +1,7 @@
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -18,9 +19,23 @@ from services.feature_service import FeatureService
from .. import console_ns
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@api.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -30,18 +45,6 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
args = parser.parse_args()
# Create service with session

View File

@ -3,11 +3,10 @@ from typing import cast
from flask import request
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
@ -48,15 +47,12 @@ class ModelConfigResource(Resource):
@api.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,

View File

@ -1,10 +1,15 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
@ -76,17 +81,13 @@ class AppSite(Resource):
@api.response(404, "App not found")
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
@ -130,16 +131,12 @@ class AppSiteAccessTokenReset(Resource):
@api.response(404, "App or site not found")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:

View File

@ -10,9 +10,9 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
from models import AppMode
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
@ -44,8 +44,9 @@ class DailyMessageStatistic(Resource):
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(*) AS message_count
FROM
messages
@ -80,16 +81,19 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily conversation statistics retrieved successfully",
@ -102,12 +106,18 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
try:
@ -115,30 +125,21 @@ class DailyConversationStatistic(Resource):
except ValueError as e:
abort(400, description=str(e))
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if start_datetime_utc:
stmt = stmt.where(Message.created_at >= start_datetime_utc)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if end_datetime_utc:
stmt = stmt.where(Message.created_at < end_datetime_utc)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
stmt = stmt.group_by("date").order_by("date")
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(stmt, {"tz": account.timezone})
for row in rs:
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
return jsonify({"data": response_data})
@ -148,11 +149,7 @@ class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily terminal statistics retrieved successfully",
@ -165,15 +162,11 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
@ -213,11 +206,7 @@ class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Daily token cost statistics retrieved successfully",
@ -230,15 +219,11 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
@ -281,11 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average session interaction statistics retrieved successfully",
@ -298,15 +279,11 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(subquery.message_count) AS interactions
FROM
(
@ -365,11 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"User satisfaction rate statistics retrieved successfully",
@ -382,15 +355,11 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
@ -439,11 +408,7 @@ class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Average response time statistics retrieved successfully",
@ -456,15 +421,11 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(provider_response_latency) AS latency
FROM
messages
@ -504,11 +465,7 @@ class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.expect(parser)
@api.response(
200,
"Tokens per second statistics retrieved successfully",
@ -520,16 +477,11 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))

View File

@ -586,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@api.doc("get_published_workflow")
@ -610,6 +617,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@api.expect(parser_publish)
@setup_required
@login_required
@account_initialization_required
@ -620,12 +628,8 @@ class PublishedWorkflowApi(Resource):
Publish workflow
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
@ -680,6 +684,9 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@api.doc("get_default_block_config")
@ -687,6 +694,7 @@ class DefaultBlockConfigApi(Resource):
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@api.response(200, "Default block configuration retrieved successfully")
@api.response(404, "Block type not found")
@api.expect(parser_block)
@setup_required
@login_required
@account_initialization_required
@ -696,8 +704,7 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_block.parse_args()
q = args.get("q")
@ -713,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@api.expect(parser_convert)
@api.doc("convert_to_workflow")
@api.doc(description="Convert application to workflow mode")
@api.doc(params={"app_id": "Application ID"})
@ -735,14 +752,7 @@ class ConvertToWorkflowApi(Resource):
current_user, _ = current_account_with_tenant()
if request.data:
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_convert.parse_args()
else:
args = {}
@ -756,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.expect(parser_workflows)
@api.doc("get_all_published_workflows")
@api.doc(description="Get all published workflows for an application")
@api.doc(params={"app_id": "Application ID"})
@ -774,16 +794,9 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@ -970,8 +983,9 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
parser = reqparse.RequestParser().add_argument(
"node_id", type=str, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
@ -1123,8 +1137,9 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"""
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
parser = reqparse.RequestParser().add_argument(
"node_ids", type=list, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService()

View File

@ -1,17 +1,18 @@
import logging
from typing import NoReturn
from collections.abc import Callable
from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
@ -21,8 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account, App, AppMode
from libs.login import login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@ -140,8 +141,11 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f):
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -155,11 +159,10 @@ def _api_prerequisite(f):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs):
return f(*args, **kwargs)
return wrapper
@ -167,6 +170,7 @@ def _api_prerequisite(f):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
@api.expect(_create_pagination_parser())
@api.doc("get_workflow_variables")
@api.doc(description="Get draft workflow variables")
@api.doc(params={"app_id": "Application ID"})

View File

@ -30,23 +30,25 @@ def _parse_workflow_run_list_args():
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()
@ -58,28 +60,30 @@ def _parse_workflow_run_count_args():
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
return parser.parse_args()

View File

@ -3,12 +3,12 @@ import logging
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@ -29,8 +29,7 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = str(args["node_id"])
@ -95,19 +94,19 @@ class AppTriggerEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
parser = (
reqparse.RequestParser()
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"]

View File

@ -1,8 +1,8 @@
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -39,12 +39,10 @@ class ApiKeyAuthDataSourceBinding(Resource):
@setup_required
@login_required
@account_initialization_required
@is_admin_or_owner_required
def post(self):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json")
@ -65,12 +63,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required
@login_required
@account_initialization_required
@is_admin_or_owner_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)

View File

@ -3,11 +3,11 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from libs.login import current_account_with_tenant, login_required
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
@ -42,11 +42,9 @@ class OAuthDataSource(Resource):
)
@api.response(400, "Invalid provider")
@api.response(403, "Admin privileges required")
@is_admin_or_owner_required
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)

View File

@ -1,6 +1,9 @@
from flask_restx import Resource, reqparse
import base64
from controllers.console import console_ns
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
@ -41,3 +44,37 @@ class Invoices(Resource):
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
class PartnerTenants(Resource):
@api.doc("sync_partner_tenants_bindings")
@api.doc(description="Sync partner tenants bindings")
@api.doc(params={"partner_key": "Partner key"})
@api.expect(
api.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@api.response(200, "Tenants synced to partner successfully")
@api.response(400, "Invalid partner information")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
click_id = args["click_id"]
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")
if not click_id or not decoded_partner_key or not current_user.id:
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)

View File

@ -15,6 +15,7 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -753,13 +754,11 @@ class DatasetApiKeyApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
current_key_count = (
db.session.query(ApiToken)
@ -794,15 +793,11 @@ class DatasetApiDeleteApi(Resource):
@api.response(204, "API key deleted successfully")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
_, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = (
db.session.query(ApiToken)
.where(

View File

@ -162,6 +162,7 @@ class DatasetDocumentListApi(Resource):
"keyword": "Search keyword",
"sort": "Sort order (default: -created_at)",
"fetch": "Fetch full details (default: false)",
"status": "Filter documents by display status",
}
)
@api.response(200, "Documents retrieved successfully")
@ -175,6 +176,7 @@ class DatasetDocumentListApi(Resource):
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
@ -203,6 +205,9 @@ class DatasetDocumentListApi(Resource):
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search:
search = f"%{search}%"
query = query.where(Document.name.like(search))

View File

@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
@ -200,12 +200,10 @@ class ExternalDatasetCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")

View File

@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@ -121,8 +121,16 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@api.expect(parser_datasource)
@setup_required
@login_required
@account_initialization_required
@ -130,14 +138,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -168,8 +169,14 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@api.expect(parser_datasource_delete)
@setup_required
@login_required
@account_initialization_required
@ -181,10 +188,7 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_datasource_delete.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
@ -195,8 +199,17 @@ class DatasourceAuthDeleteApi(Resource):
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@api.expect(parser_datasource_update)
@setup_required
@login_required
@account_initialization_required
@ -205,13 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_datasource_update.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
@ -251,8 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@api.expect(parser_datasource_custom)
@setup_required
@login_required
@account_initialization_required
@ -260,12 +275,7 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_datasource_custom.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
@ -291,8 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -300,8 +314,7 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
@ -312,8 +325,16 @@ class DatasourceAuthDefaultApi(Resource):
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@api.expect(parser_update_name)
@setup_required
@login_required
@account_initialization_required
@ -321,12 +342,7 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_update_name.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(

View File

@ -1,7 +1,7 @@
from flask_restx import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
@ -12,17 +12,21 @@ from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class Parser(BaseModel):
inputs: dict
datasource_type: str
credential_id: str | None = None
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@api.expect(parser)
@api.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required
@ -34,15 +38,10 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
args = Parser.model_validate(api.payload)
inputs = args.inputs
datasource_type = args.datasource_type
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
@ -51,6 +50,6 @@ class DataSourceContentPreviewApi(Resource):
account=current_user,
datasource_type=datasource_type,
is_published=True,
credential_id=args.get("credential_id"),
credential_id=args.credential_id,
)
return preview_content, 200

View File

@ -1,11 +1,11 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
@ -21,12 +21,11 @@ class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
@ -71,12 +70,10 @@ class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
@ -98,12 +95,9 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@ -117,12 +111,9 @@ class RagPipelineExportApi(Resource):
@login_required
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params
# Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args()

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
@ -148,8 +148,12 @@ class DraftRagPipelineApi(Resource):
}
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@ -162,8 +166,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_iteration(
@ -184,9 +187,11 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@api.expect(parser_run)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
@ -194,11 +199,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
args = parser_run.parse_args()
try:
response = PipelineGenerateService.generate_single_loop(
@ -217,11 +219,22 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@api.expect(parser_draft_run)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
@ -229,17 +242,8 @@ class DraftRagPipelineRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_draft_run.parse_args()
try:
response = PipelineGenerateService.generate(
@ -255,11 +259,25 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@api.expect(parser_published_run)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
@ -267,20 +285,8 @@ class PublishedRagPipelineRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_published_run.parse_args()
streaming = args["response_mode"] == "streaming"
@ -381,11 +387,21 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
@ -393,16 +409,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@ -429,8 +437,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
@ -439,16 +449,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
@ -473,10 +475,17 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
)
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@api.expect(parser_run_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@ -486,13 +495,8 @@ class RagPipelineDraftNodeRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_run_api.parse_args()
inputs = args.get("inputs")
if inputs == None:
@ -513,6 +517,7 @@ class RagPipelineDraftNodeRunApi(Resource):
class RagPipelineTaskStopApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, task_id: str):
@ -521,8 +526,6 @@ class RagPipelineTaskStopApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -534,6 +537,7 @@ class PublishedRagPipelineApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
def get(self, pipeline: Pipeline):
@ -541,9 +545,6 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published:
return None
# fetch published workflow by pipeline
@ -556,6 +557,7 @@ class PublishedRagPipelineApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
@ -563,9 +565,6 @@ class PublishedRagPipelineApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
@ -592,38 +591,33 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def get(self, pipeline: Pipeline):
"""
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
rag_pipeline_service = RagPipelineService()
return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@api.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, block_type: str):
"""
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
args = parser.parse_args()
args = parser_default.parse_args()
q = args.get("q")
@ -639,11 +633,22 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@api.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_pagination_fields)
def get(self, pipeline: Pipeline):
@ -651,19 +656,10 @@ class PublishedAllRagPipelineApi(Resource):
Get published workflows
"""
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args()
page = int(args.get("page", 1))
limit = int(args.get("limit", 10))
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
@ -691,11 +687,20 @@ class PublishedAllRagPipelineApi(Resource):
}
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@api.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
def patch(self, pipeline: Pipeline, workflow_id: str):
@ -704,22 +709,14 @@ class RagPipelineByIdApi(Resource):
"""
# Check permission
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_wf_id.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
args = parser.parse_args()
# Prepare update data
update_data = {}
@ -752,8 +749,12 @@ class RagPipelineByIdApi(Resource):
return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -763,8 +764,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -777,6 +777,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -786,8 +787,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -800,6 +800,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -809,8 +810,7 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -823,6 +823,7 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -832,8 +833,7 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
@ -845,8 +845,16 @@ class DraftRagPipelineSecondStepApi(Resource):
}
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@api.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@ -856,12 +864,7 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = parser_wf_run.parse_args()
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -961,8 +964,18 @@ class RagPipelineTransformApi(Resource):
return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@api.expect(parser_var)
@setup_required
@login_required
@account_initialization_required
@ -974,14 +987,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_var.parse_args()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import current_user, login_required
@ -35,15 +35,18 @@ recommended_app_list_fields = {
}
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@api.expect(parser_apps)
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
args = parser.parse_args()
args = parser_apps.parse_args()
language = args.get("language")
if language and language in languages:

View File

@ -10,6 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import api
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -36,12 +37,15 @@ class RemoteFileInfoApi(Resource):
}
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@api.expect(parser_upload)
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
args = parser_upload.parse_args()
url = args["url"]

View File

@ -49,6 +49,7 @@ class SetupApi(Resource):
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)

View File

@ -2,8 +2,8 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
@ -16,6 +16,19 @@ def _validate_name(name):
return name
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tags")
class TagListApi(Resource):
@setup_required
@ -30,6 +43,7 @@ class TagListApi(Resource):
return tags, 200
@api.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required
@ -39,20 +53,7 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -60,8 +61,14 @@ class TagListApi(Resource):
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@api.expect(parser_tag_id)
@setup_required
@login_required
@account_initialization_required
@ -72,10 +79,7 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
args = parser.parse_args()
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -87,20 +91,26 @@ class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id)
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@ -110,26 +120,23 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_create.parse_args()
TagService.save_tag_binding(args)
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@api.expect(parser_remove)
@setup_required
@login_required
@account_initialization_required
@ -139,15 +146,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
)
args = parser.parse_args()
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
return {"result": "success"}, 200

View File

@ -11,16 +11,16 @@ from . import api, console_ns
logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
@console_ns.route("/version")
class VersionApi(Resource):
@api.doc("check_version_update")
@api.doc(description="Check for application version updates")
@api.expect(
api.parser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
)
@api.expect(parser)
@api.response(
200,
"Success",
@ -37,7 +37,6 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL

View File

@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailChangeLimitError,
@ -43,8 +43,19 @@ from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
def _init_parser():
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
return parser
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@api.expect(_init_parser())
@setup_required
@login_required
def post(self):
@ -53,14 +64,7 @@ class AccountInitApi(Resource):
if account.status == "active":
raise AccountAlreadyInitedError()
parser = reqparse.RequestParser()
if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
)
args = parser.parse_args()
args = _init_parser().parse_args()
if dify_config.EDITION == "CLOUD":
if not args["invitation_code"]:
@ -106,16 +110,19 @@ class AccountProfileApi(Resource):
return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name")
class AccountNameApi(Resource):
@api.expect(parser_name)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_name.parse_args()
# Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
@ -126,68 +133,80 @@ class AccountNameApi(Resource):
return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@api.expect(parser_avatar)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_avatar.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource):
@api.expect(parser_interface)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
args = parser.parse_args()
args = parser_interface.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource):
@api.expect(parser_theme)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
args = parser.parse_args()
args = parser_theme.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource):
@api.expect(parser_timezone)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_timezone.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
@ -198,21 +217,24 @@ class AccountTimezoneApi(Resource):
return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password")
class AccountPasswordApi(Resource):
@api.expect(parser_pw)
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_pw.parse_args()
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
@ -294,20 +316,23 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete")
class AccountDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_delete.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
raise InvalidAccountDeletionCodeError()
@ -317,16 +342,19 @@ class AccountDeleteApi(Resource):
return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource):
@api.expect(parser_feedback)
@setup_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_feedback.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@ -351,6 +379,14 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
@ -360,6 +396,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean,
}
@api.expect(parser_edu)
@setup_required
@login_required
@account_initialization_required
@ -368,13 +405,7 @@ class EducationApi(Resource):
def post(self):
account, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_edu.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@ -394,6 +425,14 @@ class EducationApi(Resource):
return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
@ -402,6 +441,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean,
}
@api.expect(parser_autocomplete)
@setup_required
@login_required
@account_initialization_required
@ -409,33 +449,30 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args()
args = parser_autocomplete.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
@console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource):
@api.expect(parser_change_email)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_change_email.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@ -470,20 +507,23 @@ class ChangeEmailSendEmailApi(Resource):
return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource):
@api.expect(parser_validity)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validity.parse_args()
user_email = args["email"]
@ -514,20 +554,23 @@ class ChangeEmailCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource):
@api.expect(parser_reset)
@enable_change_email
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_reset.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()
@ -555,12 +598,15 @@ class ChangeEmailResetApi(Resource):
return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource):
@api.expect(parser_check)
@setup_required
def post(self):
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
args = parser_check.parse_args()
if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["email"]):

View File

@ -1,8 +1,7 @@
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
@ -31,11 +30,10 @@ class EndpointCreateApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
@ -168,6 +166,7 @@ class EndpointDeleteApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@ -175,9 +174,6 @@ class EndpointDeleteApi(Resource):
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"]
return {
@ -207,6 +203,7 @@ class EndpointUpdateApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@ -223,9 +220,6 @@ class EndpointUpdateApi(Resource):
settings = args["settings"]
name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
@ -252,6 +246,7 @@ class EndpointEnableApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@ -261,9 +256,6 @@ class EndpointEnableApi(Resource):
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -284,6 +276,7 @@ class EndpointDisableApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@ -293,9 +286,6 @@ class EndpointDisableApi(Resource):
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
@ -48,22 +48,25 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource):
"""Invite a new member by email."""
@api.expect(parser_invite)
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = parser_invite.parse_args()
invitee_emails = args["emails"]
invitee_role = args["role"]
@ -143,16 +146,19 @@ class MemberCancelInviteApi(Resource):
}, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource):
"""Update member role."""
@api.expect(parser_update)
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_update.parse_args()
new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role):
@ -191,17 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email."""
@api.expect(parser_send)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
args = parser_send.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@ -229,19 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource):
@api.expect(parser_owner)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_owner.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@ -276,17 +288,20 @@ class OwnerTransferCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource):
@api.expect(parser_owner_transfer)
@setup_required
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id):
parser = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_owner_transfer.parse_args()
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()

View File

@ -2,10 +2,9 @@ import io
from flask import send_file
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@ -14,9 +13,19 @@ from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
@console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource):
@api.expect(parser_model)
@setup_required
@login_required
@account_initialization_required
@ -24,15 +33,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=False,
nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_model.parse_args()
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
@ -40,8 +41,30 @@ class ModelProviderListApi(Resource):
return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource):
@api.expect(parser_cred)
@setup_required
@login_required
@account_initialization_required
@ -49,10 +72,7 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args()
args = parser_cred.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
@ -61,20 +81,14 @@ class ModelProviderCredentialApi(Resource):
return {"credentials": credentials}
@api.expect(parser_post_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@ -90,21 +104,15 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@ -121,17 +129,14 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
@ -141,19 +146,21 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()
service = ModelProviderService()
service.switch_active_provider_credential(
@ -164,17 +171,20 @@ class ModelProviderCredentialSwitchApi(Resource):
return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_validate.parse_args()
tenant_id = current_tenant_id
@ -218,27 +228,29 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource):
@api.expect(parser_preferred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
args = parser.parse_args()
args = parser_preferred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider(

View File

@ -1,10 +1,9 @@
import logging
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@ -16,23 +15,29 @@ from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__)
parser_get_default = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource):
@api.expect(parser_get_default)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
args = parser.parse_args()
args = parser_get_default.parse_args()
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
@ -41,19 +46,15 @@ class DefaultModelApi(Resource):
return jsonable_encoder({"data": default_model_entity})
@api.expect(parser_post_default)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
_, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_post_default.parse_args()
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@ -84,6 +85,35 @@ class DefaultModelApi(Resource):
return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource):
@setup_required
@ -97,32 +127,15 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models})
@api.expect(parser_post_models)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
args = parser.parse_args()
_, tenant_id = current_account_with_tenant()
args = parser_post_models.parse_args()
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
@ -160,28 +173,15 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@api.expect(parser_delete_models)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
_, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_delete_models.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model(
@ -191,29 +191,76 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_get_credentials)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_get_credentials.parse_args()
model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential(
@ -257,30 +304,15 @@ class ModelProviderModelCredentialApi(Resource):
}
)
@api.expect(parser_post_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
_, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService()
@ -304,31 +336,14 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 201
@api.expect(parser_put_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService()
@ -347,28 +362,14 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}
@api.expect(parser_delete_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
@ -382,30 +383,32 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 204
parser_switch = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
_, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_switch.parse_args()
service = ModelProviderService()
service.add_model_credential_to_model_list(
@ -418,29 +421,32 @@ class ModelProviderModelCredentialSwitchApi(Resource):
return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
)
class ModelProviderModelEnableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
@ -454,25 +460,14 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
)
class ModelProviderModelDisableApi(Resource):
@api.expect(parser_model_enable_disable)
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
args = parser.parse_args()
args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
@ -482,28 +477,31 @@ class ModelProviderModelDisableApi(Resource):
return {"result": "success"}
parser_validate = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource):
@api.expect(parser_validate)
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_validate.parse_args()
model_provider_service = ModelProviderService()
@ -530,16 +528,19 @@ class ModelProviderModelValidateApi(Resource):
return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource):
@api.expect(parser_parameter)
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
parser = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_parameter.parse_args()
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()

View File

@ -5,9 +5,9 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import current_account_with_tenant, login_required
@ -37,19 +37,22 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e)
parser_list = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@api.expect(parser_list)
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args()
args = parser_list.parse_args()
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
except PluginDaemonClientSideError as e:
@ -58,14 +61,17 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource):
@api.expect(parser_latest)
@setup_required
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
args = parser_latest.parse_args()
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
@ -75,16 +81,19 @@ class PluginListLatestVersionsApi(Resource):
return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource):
@api.expect(parser_ids)
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args()
args = parser_ids.parse_args()
try:
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
@ -94,16 +103,19 @@ class PluginListInstallationsFromIdsApi(Resource):
return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource):
@api.expect(parser_icon)
@setup_required
def get(self):
req = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args()
args = parser_icon.parse_args()
try:
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
@ -120,9 +132,11 @@ class PluginAssetApi(Resource):
@login_required
@account_initialization_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
req.add_argument("file_name", type=str, required=True, location="args")
req = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("file_name", type=str, required=True, location="args")
)
args = req.parse_args()
_, tenant_id = current_account_with_tenant()
@ -157,8 +171,17 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource):
@api.expect(parser_github)
@setup_required
@login_required
@account_initialization_required
@ -166,13 +189,7 @@ class PluginUploadFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github.parse_args()
try:
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
@ -206,19 +223,21 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource):
@api.expect(parser_pkg)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_pkg.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@ -233,8 +252,18 @@ class PluginInstallFromPkgApi(Resource):
return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource):
@api.expect(parser_githubapi)
@setup_required
@login_required
@account_initialization_required
@ -242,14 +271,7 @@ class PluginInstallFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_githubapi.parse_args()
try:
response = PluginService.install_from_github(
@ -265,8 +287,14 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource):
@api.expect(parser_marketplace)
@setup_required
@login_required
@account_initialization_required
@ -274,10 +302,7 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args()
args = parser_marketplace.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
@ -292,19 +317,21 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource):
@api.expect(parser_pkgapi)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_pkgapi.parse_args()
try:
return jsonable_encoder(
@ -319,8 +346,14 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource):
@api.expect(parser_fetch)
@setup_required
@login_required
@account_initialization_required
@ -328,10 +361,7 @@ class PluginFetchManifestApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args()
args = parser_fetch.parse_args()
try:
return jsonable_encoder(
@ -345,8 +375,16 @@ class PluginFetchManifestApi(Resource):
raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource):
@api.expect(parser_tasks)
@setup_required
@login_required
@account_initialization_required
@ -354,12 +392,7 @@ class PluginFetchInstallTasksApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
args = parser_tasks.parse_args()
try:
return jsonable_encoder(
@ -429,8 +462,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource):
@api.expect(parser_marketplace_api)
@setup_required
@login_required
@account_initialization_required
@ -438,12 +479,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_marketplace_api.parse_args()
try:
return jsonable_encoder(
@ -455,8 +491,19 @@ class PluginUpgradeFromMarketplaceApi(Resource):
raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource):
@api.expect(parser_github_post)
@setup_required
@login_required
@account_initialization_required
@ -464,15 +511,7 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self):
_, tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args()
args = parser_github_post.parse_args()
try:
return jsonable_encoder(
@ -489,15 +528,20 @@ class PluginUpgradeFromGithubApi(Resource):
raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource):
@api.expect(parser_uninstall)
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_uninstall.parse_args()
_, tenant_id = current_account_with_tenant()
@ -507,8 +551,16 @@ class PluginUninstallApi(Resource):
raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource):
@api.expect(parser_change_post)
@setup_required
@login_required
@account_initialization_required
@ -518,12 +570,7 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args()
args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
@ -558,29 +605,29 @@ class PluginFetchPermissionApi(Resource):
)
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource):
@api.expect(parser_dynamic)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
parser = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args()
args = parser_dynamic.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
@ -599,8 +646,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@api.expect(parser_change)
@setup_required
@login_required
@account_initialization_required
@ -609,12 +664,7 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
req = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args()
args = parser_change.parse_args()
permission = args["permission"]
@ -694,8 +744,12 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource):
@api.expect(parser_exclude)
@setup_required
@login_required
@account_initialization_required
@ -703,8 +757,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
args = parser_exclude.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
@ -716,9 +769,11 @@ class PluginReadmeApi(Resource):
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("language", type=str, required=False, location="args")
)
args = parser.parse_args()
return jsonable_encoder(
{

View File

@ -10,10 +10,11 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@ -52,8 +53,19 @@ def is_valid_url(url: str) -> bool:
return False
parser_tool = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
@console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource):
@api.expect(parser_tool)
@setup_required
@login_required
@account_initialization_required
@ -62,15 +74,7 @@ class ToolProviderListApi(Resource):
user_id = user.id
req = reqparse.RequestParser().add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
)
args = req.parse_args()
args = parser_tool.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
@ -102,20 +106,22 @@ class ToolBuiltinProviderInfoApi(Resource):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
parser_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
args = req.parse_args()
args = parser_delete.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
tenant_id,
@ -124,8 +130,17 @@ class ToolBuiltinProviderDeleteApi(Resource):
)
parser_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource):
@api.expect(parser_add)
@setup_required
@login_required
@account_initialization_required
@ -134,13 +149,7 @@ class ToolBuiltinProviderAddApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_add.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
@ -155,27 +164,26 @@ class ToolBuiltinProviderAddApi(Resource):
)
parser_update = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource):
@api.expect(parser_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_update.parse_args()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
@ -213,32 +221,32 @@ class ToolBuiltinProviderIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
parser_api_add = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource):
@api.expect(parser_api_add)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_add.parse_args()
return ApiToolManageService.create_api_tool_provider(
user_id,
@ -254,8 +262,12 @@ class ToolApiProviderAddApi(Resource):
)
parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource):
@api.expect(parser_remote)
@setup_required
@login_required
@account_initialization_required
@ -264,9 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
args = parser_remote.parse_args()
return ApiToolManageService.get_api_tool_provider_remote_schema(
user_id,
@ -275,8 +285,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
)
parser_tools = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource):
@api.expect(parser_tools)
@setup_required
@login_required
@account_initialization_required
@ -285,11 +301,7 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_tools.parse_args()
return jsonable_encoder(
ApiToolManageService.list_api_tool_provider_tools(
@ -300,33 +312,33 @@ class ToolApiProviderListToolsApi(Resource):
)
parser_api_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource):
@api.expect(parser_api_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_api_update.parse_args()
return ApiToolManageService.update_api_tool_provider(
user_id,
@ -343,24 +355,24 @@ class ToolApiProviderUpdateApi(Resource):
)
parser_api_delete = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource):
@api.expect(parser_api_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_api_delete.parse_args()
return ApiToolManageService.delete_api_tool_provider(
user_id,
@ -369,8 +381,12 @@ class ToolApiProviderDeleteApi(Resource):
)
parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args")
@console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource):
@api.expect(parser_get)
@setup_required
@login_required
@account_initialization_required
@ -379,11 +395,7 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_get.parse_args()
return ApiToolManageService.get_api_tool_provider(
user_id,
@ -407,40 +419,44 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
)
parser_schema = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource):
@api.expect(parser_schema)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_schema.parse_args()
return ApiToolManageService.parser_api_schema(
schema=args["schema"],
)
parser_pre = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource):
@api.expect(parser_pre)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
.add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = parser_pre.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_tenant_id,
@ -453,32 +469,32 @@ class ToolApiProviderPreviousTestApi(Resource):
)
parser_create = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_create.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id=user_id,
@ -494,32 +510,31 @@ class ToolWorkflowProviderCreateApi(Resource):
)
parser_workflow_update = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource):
@api.expect(parser_workflow_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
reqparser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
.add_argument("label", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args()
args = parser_workflow_update.parse_args()
if not args["workflow_tool_id"]:
raise ValueError("incorrect workflow_tool_id")
@ -538,24 +553,24 @@ class ToolWorkflowProviderUpdateApi(Resource):
)
parser_workflow_delete = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource):
@api.expect(parser_workflow_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
reqparser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args()
args = parser_workflow_delete.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
@ -564,8 +579,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
)
parser_wf_get = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource):
@api.expect(parser_wf_get)
@setup_required
@login_required
@account_initialization_required
@ -574,13 +597,7 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id
parser = (
reqparse.RequestParser()
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args()
args = parser_wf_get.parse_args()
if args.get("workflow_tool_id"):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
@ -600,8 +617,14 @@ class ToolWorkflowProviderGetApi(Resource):
return jsonable_encoder(tool)
parser_wf_tools = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource):
@api.expect(parser_wf_tools)
@setup_required
@login_required
@account_initialization_required
@ -610,11 +633,7 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id
parser = reqparse.RequestParser().add_argument(
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args()
args = parser_wf_tools.parse_args()
return jsonable_encoder(
WorkflowToolManageService.list_single_workflow_tools(
@ -699,18 +718,15 @@ class ToolLabelsApi(Resource):
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
@ -790,37 +806,43 @@ class ToolOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_default_cred = reqparse.RequestParser().add_argument(
"id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource):
@api.expect(parser_default_cred)
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
args = parser_default_cred.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
parser_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource):
@api.expect(parser_custom)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
def post(self, provider: str):
args = parser_custom.parse_args()
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=tenant_id,
@ -878,25 +900,44 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
)
parser_mcp = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_put = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
parser_mcp_delete = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource):
@api.expect(parser_mcp)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
args = parser_mcp.parse_args()
user, tenant_id = current_account_with_tenant()
# Parse and validate models
@ -921,24 +962,12 @@ class ToolProviderMCPApi(Resource):
)
return jsonable_encoder(result)
@api.expect(parser_mcp_put)
@setup_required
@login_required
@account_initialization_required
def put(self):
parser = (
reqparse.RequestParser()
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("icon", type=str, required=True, nullable=False, location="json")
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args()
args = parser_mcp_put.parse_args()
configuration = MCPConfiguration.model_validate(args["configuration"])
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
_, current_tenant_id = current_account_with_tenant()
@ -972,14 +1001,12 @@ class ToolProviderMCPApi(Resource):
)
return {"result": "success"}
@api.expect(parser_mcp_delete)
@setup_required
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser().add_argument(
"provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
args = parser_mcp_delete.parse_args()
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
@ -988,18 +1015,21 @@ class ToolProviderMCPApi(Resource):
return {"result": "success"}
parser_auth = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource):
@api.expect(parser_auth)
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
args = parser_auth.parse_args()
provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant()
@ -1035,7 +1065,13 @@ class ToolMCPAuthApi(Resource):
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
# Pass the extracted OAuth metadata hints to auth()
auth_result = auth(
provider_entity,
args.get("authorization_code"),
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
@ -1045,7 +1081,7 @@ class ToolMCPAuthApi(Resource):
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
@ -1097,15 +1133,18 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
parser_cb = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
@console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource):
@api.expect(parser_cb)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args()
args = parser_cb.parse_args()
state_key = args["state"]
authorization_code = args["code"]

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -67,14 +67,12 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return jsonable_encoder(
@ -92,17 +90,16 @@ class TriggerSubscriptionListApi(Resource):
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
args = parser.parse_args()
try:
@ -133,18 +130,17 @@ class TriggerSubscriptionBuilderGetApi(Resource):
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
@ -173,15 +169,17 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
return jsonable_encoder(
@ -223,24 +221,23 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
@ -264,14 +261,12 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, subscription_id: str):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
with Session(db.engine) as session:
@ -446,14 +441,12 @@ class TriggerOAuthCallbackApi(Resource):
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
@ -493,18 +486,18 @@ class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
@ -524,14 +517,12 @@ class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
@ -128,7 +128,7 @@ class TenantApi(Resource):
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
def get(self):
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@ -150,15 +150,18 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_switch.parse_args()
# check if tenant_id is valid, 403 if not
try:
@ -242,16 +245,19 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource):
@api.expect(parser_info)
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = parser_info.parse_args()
if not current_tenant_id:
raise ValueError("No current tenant")

View File

@ -315,3 +315,19 @@ def edit_permission_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -3,14 +3,12 @@ from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.api import HTTPStatus
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models import Account
from models.model import App
from services.annotation_service import AppAnnotationService
@ -161,14 +159,10 @@ class AnnotationUpdateDeleteApi(Resource):
}
)
@validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id):
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@ -185,13 +179,8 @@ class AnnotationUpdateDeleteApi(Resource):
}
)
@validate_app_token
def delete(self, app_model: App, annotation_id):
@edit_permission_required
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204

View File

@ -5,6 +5,7 @@ from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
@ -619,11 +620,9 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"])

View File

@ -1,7 +1,10 @@
import json
from typing import Self
from uuid import UUID
from flask import request
from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -31,7 +34,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
# Define parsers for document operations
@ -51,15 +54,26 @@ document_text_create_parser = (
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
document_text_update_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("text", type=str, required=False, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class DocumentTextUpdate(BaseModel):
name: str | None = None
text: str | None = None
process_rule: ProcessRule | None = None
doc_form: str = "text_model"
doc_language: str = "English"
retrieval_model: RetrievalModel | None = None
@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:
raise ValueError("name is required when text is provided")
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@service_api_ns.route(
@ -160,7 +174,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@service_api_ns.expect(document_text_update_parser)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -173,12 +187,10 @@ class DocumentUpdateByTextApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
args = document_text_update_parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@ -198,11 +210,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args["text"]:
if args.get("text"):
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
@ -456,12 +466,16 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
status = request.args.get("status", default=None, type=str)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search:
search = f"%{search}%"
query = query.where(Document.name.like(search))

View File

@ -88,12 +88,6 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
text_to_audio_response_fields = {
"audio_url": fields.String,
"duration": fields.Float,
}
@marshal_with(text_to_audio_response_fields)
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(

View File

@ -81,6 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
user_id = request.args.get("user_id")
token = extract_webapp_access_token(request)
if not app_code:
return {
@ -103,7 +104,7 @@ class LoginStatusApi(Resource):
user_logged_in = False
try:
_ = decode_jwt_token(app_code=app_code)
_ = decode_jwt_token(app_code=app_code, user_id=user_id)
app_logged_in = True
except Exception:
app_logged_in = False

View File

@ -38,7 +38,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator
def decode_jwt_token(app_code: str | None = None):
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
system_features = FeatureService.get_system_features()
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
@ -63,6 +63,10 @@ def decode_jwt_token(app_code: str | None = None):
if not end_user:
raise NotFound()
# Validate user_id against end_user's session_id if provided
if user_id is not None and end_user.session_id != user_id:
raise Unauthorized("Authentication has expired.")
# for enterprise webapp auth
app_web_auth_enabled = False
webapp_settings = None

View File

@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
input_data=dict(inputs),
pipeline_id=pipeline.id,
created_by=user.id,
)

View File

@ -145,7 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# for trigger debug run, not prepare user inputs
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,

View File

@ -644,14 +644,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if not workflow_run_id:
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
workflow_app_log = WorkflowAppLog(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
workflow_id=self._workflow.id,
workflow_run_id=workflow_run_id,
created_from=created_from.value,
created_by_role=self._created_by_role,
created_by=self._user_id,
)
session.add(workflow_app_log)
session.commit()

View File

@ -1,14 +1,10 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import Any
from pydantic import BaseModel, Field
# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
if TYPE_CHECKING:
from core.app.entities.app_invoke_entities import InvokeFrom
class DatasourceRuntime(BaseModel):
"""
@ -17,7 +13,7 @@ class DatasourceRuntime(BaseModel):
tenant_id: str
datasource_id: str | None = None
invoke_from: Optional["InvokeFrom"] = None
invoke_from: InvokeFrom | None = None
datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@ -6,7 +6,8 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse
from httpx import ConnectError, HTTPStatusError, RequestError
import httpx
from httpx import RequestError
from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
@ -20,6 +21,7 @@ from core.mcp.types import (
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
from extensions.ext_redis import redis_client
@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge
def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
return urls
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
return urls
def discover_protected_resource_metadata(
prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def discover_oauth_authorization_server_metadata(
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def get_effective_scope(
scope_from_www_auth: str | None,
prm: ProtectedResourceMetadata | None,
asm: OAuthMetadata | None,
client_scope: str | None,
) -> str | None:
"""
Determine effective scope using priority-based selection strategy.
Priority order:
1. WWW-Authenticate header scope (server explicit requirement)
2. Protected Resource Metadata scopes
3. OAuth Authorization Server Metadata scopes
4. Client configured scope
"""
if scope_from_www_auth:
return scope_from_www_auth
if prm and prm.scopes_supported:
return " ".join(prm.scopes_supported)
if asm and asm.scopes_supported:
return " ".join(asm.scopes_supported)
return client_scope
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
return False, ""
def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
def discover_oauth_metadata(
server_url: str,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
protocol_version: str | None = None,
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
"""
Discover OAuth metadata using RFC 8414/9470 standards.
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
Args:
server_url: The MCP server URL
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
scope_hint: Scope hint from WWW-Authenticate header
protocol_version: MCP protocol version
for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
continue
if not response.is_success:
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
Returns:
(oauth_metadata, protected_resource_metadata, scope_hint)
"""
# Discover Protected Resource Metadata
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
return None # No metadata found
# Get authorization server URL from PRM or use server URL
auth_server_url = None
if prm and prm.authorization_servers:
auth_server_url = prm.authorization_servers[0]
# Discover OAuth Authorization Server Metadata
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
return asm, prm, scope_hint
def start_authorization(
@ -166,6 +287,7 @@ def start_authorization(
redirect_url: str,
provider_id: str,
tenant_id: str,
scope: str | None = None,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
@ -175,13 +297,6 @@ def start_authorization(
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else:
authorization_url = urljoin(server_url, "/authorize")
@ -210,10 +325,49 @@ def start_authorization(
"state": state_key,
}
# Add scope if provided
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
"""
Parse OAuth token response supporting both JSON and form-urlencoded formats.
Per RFC 6749 Section 5.1, the standard format is JSON.
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
application/x-www-form-urlencoded format for backwards compatibility.
Args:
response: The HTTP response from token endpoint
Returns:
Parsed OAuth tokens
Raises:
ValueError: If response cannot be parsed
"""
content_type = response.headers.get("content-type", "").lower()
if "application/json" in content_type:
# Standard OAuth 2.0 JSON response (RFC 6749)
return OAuthTokens.model_validate(response.json())
elif "application/x-www-form-urlencoded" in content_type:
# Legacy form-urlencoded response (non-standard but used by some providers)
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
else:
# No content-type or unknown - try JSON first, fallback to form-urlencoded
try:
return OAuthTokens.model_validate(response.json())
except (ValidationError, json.JSONDecodeError):
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
def exchange_authorization(
server_url: str,
metadata: OAuthMetadata | None,
@ -246,7 +400,7 @@ def exchange_authorization(
response = ssrf_proxy.post(token_url, data=params)
if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def refresh_authorization(
@ -279,7 +433,7 @@ def refresh_authorization(
raise MCPRefreshTokenError(e) from e
if not response.is_success:
raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def client_credentials_flow(
@ -322,7 +476,7 @@ def client_credentials_flow(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
)
return OAuthTokens.model_validate(response.json())
return _parse_token_response(response)
def register_client(
@ -352,6 +506,8 @@ def auth(
provider: MCPProviderEntity,
authorization_code: str | None = None,
state_param: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> AuthResult:
"""
Orchestrates the full auth flow with a server using secure Redis state storage.
@ -363,18 +519,26 @@ def auth(
provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns:
AuthResult containing actions to be performed and response data
"""
actions: list[AuthAction] = []
server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
# Discover OAuth metadata using RFC 8414/9470 standards
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
)
client_metadata = provider.client_metadata
provider_id = provider.id
tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url
credentials = provider.decrypt_credentials()
# Determine grant type based on server metadata
if not server_metadata:
@ -392,8 +556,8 @@ def auth(
else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials
credentials = provider.decrypt_credentials()
# Determine effective scope using priority-based strategy
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
if not client_information:
if authorization_code is not None:
@ -425,12 +589,11 @@ def auth(
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction
try:
scope = credentials.get("scope")
tokens = client_credentials_flow(
server_url,
server_metadata,
client_information,
scope,
effective_scope,
)
# Return action to save tokens and grant type
@ -526,6 +689,7 @@ def auth(
redirect_url,
provider_id,
tenant_id,
effective_scope,
)
# Return action to save code verifier

View File

@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
# Extract OAuth metadata hints from the error
mcp_service.auth_with_actions(
self.provider_entity,
self.authorization_code,
resource_metadata_url=error.resource_metadata_url,
scope_hint=error.scope_hint,
)
# Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity(

View File

@ -290,7 +290,7 @@ def sse_client(
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPAuthError(response=exc.response)
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")

View File

@ -138,6 +138,10 @@ class StreamableHTTPTransport:
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
if not sse.data.strip():
return False
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug("SSE message: %s", message)

View File

@ -1,3 +1,10 @@
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
class MCPError(Exception):
pass
@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError):
pass
def __init__(
self,
message: str | None = None,
response: "httpx.Response | None" = None,
www_authenticate_header: str | None = None,
):
"""
MCP Authentication Error.
Args:
message: Error message
response: HTTP response object (will extract WWW-Authenticate header if provided)
www_authenticate_header: Pre-extracted WWW-Authenticate header value
"""
super().__init__(message or "Authentication failed")
# Extract OAuth metadata hints from WWW-Authenticate header
if response is not None:
www_authenticate_header = response.headers.get("WWW-Authenticate")
self.resource_metadata_url: str | None = None
self.scope_hint: str | None = None
if www_authenticate_header:
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
@staticmethod
def _extract_field(www_auth: str, field_name: str) -> str | None:
"""Extract a specific field from the WWW-Authenticate header."""
# Pattern to match field="value" or field=value
pattern = rf'{field_name}="([^"]*)"'
match = re.search(pattern, www_auth)
if match:
return match.group(1)
# Try without quotes
pattern = rf"{field_name}=([^\s,]+)"
match = re.search(pattern, www_auth)
if match:
return match.group(1)
return None
class MCPRefreshTokenError(MCPError):

View File

@ -149,7 +149,7 @@ class BaseSession(
messages when entered.
"""
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
@ -230,7 +230,7 @@ class BaseSession(
request_id = self._request_id
self._request_id = request_id + 1
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
self._response_streams[request_id] = response_queue
try:
@ -261,11 +261,17 @@ class BaseSession(
message="No response received",
)
)
elif isinstance(response_or_error, HTTPStatusError):
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
raise MCPAuthError(message=response_or_error.error.message)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
@ -327,13 +333,17 @@ class BaseSession(
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):

View File

@ -23,7 +23,7 @@ for reference.
not separate types in the schema.
"""
# Client support both version, not support 2025-06-18 yet.
LATEST_PROTOCOL_VERSION = "2025-03-26"
LATEST_PROTOCOL_VERSION = "2025-06-18"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel):
response_types_supported: list[str]
grant_types_supported: list[str] | None = None
code_challenge_methods_supported: list[str] | None = None
scopes_supported: list[str] | None = None
class ProtectedResourceMetadata(BaseModel):
"""OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
resource: str | None = None
authorization_servers: list[str]
scopes_supported: list[str] | None = None
bearer_methods_supported: list[str] | None = None

View File

@ -52,7 +52,7 @@ class OpenAIModeration(Moderation):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
)
openai_moderation = model_instance.invoke_moderation(text=text)

View File

@ -274,6 +274,8 @@ class OpsTraceManager:
raise ValueError("App not found")
tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)

View File

@ -1,12 +1,20 @@
import logging
import os
import uuid
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from typing import Any, cast
import wandb
import weave
from sqlalchemy.orm import sessionmaker
from weave.trace_server.trace_server_interface import (
CallEndReq,
CallStartReq,
EndedCallSchemaForInsert,
StartedCallSchemaForInsert,
SummaryInsertMap,
TraceStatus,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
@ -57,6 +65,7 @@ class WeaveDataTrace(BaseTraceInstance):
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {}
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
def get_project_url(
self,
@ -424,6 +433,13 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug("Weave API check failed: %s", str(e))
raise ValueError(f"Weave API check failed: {str(e)}")
def _normalize_time(self, dt: datetime | None) -> datetime:
if dt is None:
return datetime.now(UTC)
if dt.tzinfo is None:
return dt.replace(tzinfo=UTC)
return dt
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
inputs = run_data.inputs
if inputs is None:
@ -437,19 +453,71 @@ class WeaveDataTrace(BaseTraceInstance):
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
call = self.weave_client.create_call(
op=run_data.op,
inputs=inputs,
attributes=attributes,
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
if trace_id is None:
trace_id = run_data.id
call_start_req = CallStartReq(
start=StartedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
op_name=str(run_data.op),
trace_id=trace_id,
parent_id=parent_run_id,
started_at=started_at,
attributes=attributes,
inputs=inputs,
wb_user_id=None,
)
)
self.calls[run_data.id] = call
if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id
self.weave_client.server.call_start(call_start_req)
self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id)
if call:
exception = Exception(run_data.exception) if run_data.exception else None
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
else:
call_meta = self.calls.get(run_data.id)
if not call_meta:
raise ValueError(f"Call with id {run_data.id} not found")
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
if elapsed_ms < 0:
elapsed_ms = 0
status_counts = {
TraceStatus.SUCCESS: 0,
TraceStatus.ERROR: 0,
}
if run_data.exception:
status_counts[TraceStatus.ERROR] = 1
else:
status_counts[TraceStatus.SUCCESS] = 1
summary: dict[str, Any] = {
"status_counts": status_counts,
"weave": {"latency_ms": elapsed_ms},
}
exception_str = str(run_data.exception) if run_data.exception else None
call_end_req = CallEndReq(
end=EndedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
ended_at=ended_at,
exception=exception_str,
output=run_data.outputs,
summary=cast(SummaryInsertMap, summary),
)
)
self.weave_client.server.call_end(call_end_req)

View File

@ -309,11 +309,12 @@ class ProviderManager:
(model for model in available_models if model.model == "gpt-4"), available_models[0]
)
default_model = TenantDefaultModel()
default_model.tenant_id = tenant_id
default_model.model_type = model_type.to_origin_model_type()
default_model.provider_name = available_model.provider.provider
default_model.model_name = available_model.model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
provider_name=available_model.provider.provider,
model_name=available_model.model,
)
db.session.add(default_model)
db.session.commit()

View File

@ -22,6 +22,18 @@ logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector")
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneConfig(BaseModel):
host: str = "localhost"
@ -206,19 +218,6 @@ class MatrixoneVector(BaseVector):
self.client.delete()
T = TypeVar("T", bound=MatrixoneVector)
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:

View File

@ -152,13 +152,15 @@ class WordExtractor(BaseExtractor):
# Initialize a row, all of which are empty by default
row_cells = [""] * total_cols
col_index = 0
for cell in row.cells:
while col_index < len(row.cells):
# make sure the col_index is not out of range
while col_index < total_cols and row_cells[col_index] != "":
while col_index < len(row.cells) and row_cells[col_index] != "":
col_index += 1
# if col_index is out of range the loop is jumped
if col_index >= total_cols:
if col_index >= len(row.cells):
break
# get the correct cell
cell = row.cells[col_index]
cell_content = self._parse_cell(cell, image_map).strip()
cell_colspan = cell.grid_span or 1
for i in range(cell_colspan):

View File

@ -54,6 +54,9 @@ class TenantIsolatedTaskQueue:
serialized_data = wrapper.serialize()
serialized_tasks.append(serialized_data)
if not serialized_tasks:
return
redis_client.lpush(self._queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> Sequence[Any]:

View File

@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, or_, select
from core.app.app_config.entities import (
DatasetEntity,
@ -1023,60 +1022,55 @@ class DatasetRetrieval:
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
):
if value is None and condition not in ("empty", "not empty"):
return
return filters
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:
case "contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"{value}%"}
)
)
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}"}
)
)
filters.append(json_field.like(f"%{value}"))
case "is" | "=":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
def _fetch_model_config(

View File

@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from yarl import URL
import contexts
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@ -32,7 +33,6 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
@ -63,7 +63,6 @@ from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from core.workflow.runtime import VariablePool
logger = logging.getLogger(__name__)
@ -618,12 +617,28 @@ class ToolManager:
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use DISTINCT ON
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
else:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
) ranked WHERE rn = 1
"""
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()

View File

@ -1,9 +1,12 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from core.file.models import File
if TYPE_CHECKING:
pass
class ArrayValidation(StrEnum):
"""Strategy for validating array elements.
@ -155,6 +158,17 @@ class SegmentType(StrEnum):
return isinstance(value, File)
elif self == SegmentType.NONE:
return value is None
elif self == SegmentType.GROUP:
from .segment_group import SegmentGroup
from .segments import Segment
if isinstance(value, SegmentGroup):
return all(isinstance(item, Segment) for item in value.value)
if isinstance(value, list):
return all(isinstance(item, Segment) for item in value)
return False
else:
raise AssertionError("this statement should be unreachable.")
@ -202,6 +216,35 @@ class SegmentType(StrEnum):
raise ValueError(f"element_type is only supported by array type, got {self}")
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: "SegmentType"):
# Lazy import to avoid circular dependency
from factories import variable_factory
match t:
case (
SegmentType.ARRAY_OBJECT
| SegmentType.ARRAY_ANY
| SegmentType.ARRAY_STRING
| SegmentType.ARRAY_NUMBER
| SegmentType.ARRAY_BOOLEAN
):
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return variable_factory.build_segment(False)
case _:
raise ValueError(f"unsupported variable type: {t}")
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have corresponding element type.

View File

@ -192,7 +192,6 @@ class GraphEngine:
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
)

View File

@ -43,7 +43,6 @@ class Dispatcher:
self,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
event_collector: EventManager,
execution_coordinator: ExecutionCoordinator,
event_emitter: EventManager | None = None,
) -> None:
@ -53,13 +52,11 @@ class Dispatcher:
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting unhandled events
execution_coordinator: Coordinator for execution flow
event_emitter: Optional event manager to signal completion
"""
self._event_queue = event_queue
self._event_handler = event_handler
self._event_collector = event_collector
self._execution_coordinator = execution_coordinator
self._event_emitter = event_emitter
@ -86,37 +83,31 @@ class Dispatcher:
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
self._process_commands()
while not self._stop_event.is_set():
commands_checked = False
should_check_commands = False
should_break = False
if (
self._execution_coordinator.aborted
or self._execution_coordinator.paused
or self._execution_coordinator.execution_complete
):
break
if self._execution_coordinator.is_execution_complete():
should_check_commands = True
should_break = True
else:
# Check for scaling
self._execution_coordinator.check_scaling()
self._execution_coordinator.check_scaling()
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
time.sleep(0.1)
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
should_check_commands = self._should_check_commands(event)
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
should_check_commands = True
time.sleep(0.1)
if should_check_commands and not commands_checked:
self._execution_coordinator.check_commands()
commands_checked = True
if should_break:
if not commands_checked:
self._execution_coordinator.check_commands()
self._process_commands()
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
except Exception as e:
@ -129,6 +120,6 @@ class Dispatcher:
if self._event_emitter:
self._event_emitter.mark_complete()
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
"""Return True if the event represents a node completion."""
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)
def _process_commands(self, event: GraphNodeEventBase | None = None):
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()

View File

@ -40,7 +40,7 @@ class ExecutionCoordinator:
self._command_processor = command_processor
self._worker_pool = worker_pool
def check_commands(self) -> None:
def process_commands(self) -> None:
"""Process any pending commands."""
self._command_processor.process_commands()
@ -48,24 +48,16 @@ class ExecutionCoordinator:
"""Check and perform worker scaling if needed."""
self._worker_pool.check_and_scale()
def is_execution_complete(self) -> bool:
"""
Check if execution is complete.
Returns:
True if execution is complete
"""
# Treat paused, aborted, or failed executions as terminal states
if self._graph_execution.is_paused:
return True
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
@property
def execution_complete(self):
return self._state_manager.is_execution_complete()
@property
def is_paused(self) -> bool:
def aborted(self):
return self._graph_execution.aborted or self._graph_execution.has_error
@property
def paused(self) -> bool:
"""Expose whether the underlying graph execution is paused."""
return self._graph_execution.is_paused

View File

@ -6,8 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -597,79 +596,79 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
if value is None and condition not in ("empty", "not empty"):
return filters
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
json_field = Document.doc_metadata[metadata_name].as_string()
match condition:
case "contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.like(f"%{value}%"))
case "not contains":
filters.append(
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
filters.append(json_field.notlike(f"%{value}%"))
case "start with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"{value}%"}
)
)
filters.append(json_field.like(f"{value}%"))
case "end with":
filters.append(
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}"}
)
)
filters.append(json_field.like(f"%{value}"))
case "in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(False))
else:
filters.append(json_field.in_(value_list))
case "not in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
value_list = [v.strip() for v in value.split(",") if v.strip()]
elif isinstance(value, (list, tuple)):
value_list = [str(v) for v in value if v is not None]
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
case "=" | "is":
value_list = [str(value)] if value is not None else []
if not value_list:
filters.append(literal(True))
else:
filters.append(json_field.notin_(value_list))
case "is" | "=":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
filters.append(json_field == value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
case "is not" | "":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
filters.append(json_field != value)
elif isinstance(value, (int, float)):
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
case "" | "<=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
case _:
pass
return filters
@classmethod

View File

@ -2,7 +2,6 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
@ -12,7 +11,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
@ -116,7 +114,7 @@ class VariableAssignerNode(Node):
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
# Over write the variable.
@ -143,24 +141,3 @@ class VariableAssignerNode(Node):
process_data=common_helpers.set_updated_variables({}, updated_variables),
outputs={},
)
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
return variable_factory.build_segment_with_type(t, [])
case SegmentType.OBJECT:
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case SegmentType.BOOLEAN:
return BooleanSegment(value=False)
case _:
raise VariableOperatorNodeError(f"unsupported variable type: {t}")

View File

@ -1,14 +0,0 @@
from core.variables import SegmentType
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
SegmentType.BOOLEAN: False,
SegmentType.OBJECT: {},
SegmentType.ARRAY_ANY: [],
SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_BOOLEAN: [],
}

View File

@ -16,7 +16,6 @@ from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNod
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
@ -249,7 +248,7 @@ class VariableAssignerNode(Node):
case Operation.OVER_WRITE:
return value
case Operation.CLEAR:
return EMPTY_VALUE_MAPPING[variable.value_type]
return SegmentType.get_zero_value(variable.value_type).to_object()
case Operation.APPEND:
return variable.value + [value]
case Operation.EXTEND:

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
@ -100,8 +99,8 @@ class ResponseStreamCoordinatorProtocol(Protocol):
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: TypingMapping[str, object]
edges: TypingMapping[str, object]
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...

View File

@ -265,6 +265,45 @@ def _assert_not_empty(*, value: object) -> bool:
return False
def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]:
"""
Normalize value and expected to compatible numeric types for comparison.
Args:
value: The actual numeric value (int or float)
expected: The expected value (int, float, or str)
Returns:
A tuple of (normalized_value, normalized_expected) with compatible types
Raises:
ValueError: If expected cannot be converted to a number
"""
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to number")
# Convert expected to appropriate numeric type
if isinstance(expected, str):
# Try to convert to float first to handle decimal strings
try:
expected_float = float(expected)
except ValueError as e:
raise ValueError(f"Cannot convert '{expected}' to number") from e
# If value is int and expected is a whole number, keep as int comparison
if isinstance(value, int) and expected_float.is_integer():
return value, int(expected_float)
else:
# Otherwise convert value to float for comparison
return float(value) if isinstance(value, int) else value, expected_float
elif isinstance(expected, float):
# If expected is already float, convert int value to float
return float(value) if isinstance(value, int) else value, expected
else:
# expected is int
return value, expected
def _assert_equal(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -324,18 +363,8 @@ def _assert_greater_than(*, value: object, expected: object) -> bool:
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value <= expected:
return False
return True
value, expected = _normalize_numeric_values(value, expected)
return value > expected
def _assert_less_than(*, value: object, expected: object) -> bool:
@ -345,18 +374,8 @@ def _assert_less_than(*, value: object, expected: object) -> bool:
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value >= expected:
return False
return True
value, expected = _normalize_numeric_values(value, expected)
return value < expected
def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
@ -366,18 +385,8 @@ def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value < expected:
return False
return True
value, expected = _normalize_numeric_values(value, expected)
return value >= expected
def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
@ -387,18 +396,8 @@ def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
if not isinstance(value, (int, float)):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value > expected:
return False
return True
value, expected = _normalize_numeric_values(value, expected)
return value <= expected
def _assert_null(*, value: object) -> bool:

View File

@ -421,4 +421,10 @@ class WorkflowEntry:
if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
input_value = {variable_key_list[1]: input_value}
variable_key_list = variable_key_list[0:1]
# Support for a single node to reference multiple structured_output variables
current_variable = variable_pool.get([variable_node_id] + variable_key_list)
if current_variable and isinstance(current_variable.value, dict):
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)

209
api/enums/quota_type.py Normal file
View File

@ -0,0 +1,209 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto
logger = logging.getLogger(__name__)
@dataclass
class QuotaCharge:
"""
Result of a quota consumption operation.
Attributes:
success: Whether the quota charge succeeded
charge_id: UUID for refund, or None if failed/disabled
"""
success: bool
charge_id: str | None
_quota_type: "QuotaType"
def refund(self) -> None:
"""
Refund this quota charge.
Safe to call even if charge failed or was disabled.
This method guarantees no exceptions will be raised.
"""
if self.charge_id:
self._quota_type.refund(self.charge_id)
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
class QuotaType(StrEnum):
"""
Supported quota types for tenant feature usage.
Add additional types here whenever new billable features become available.
"""
# Trigger execution quota
TRIGGER = auto()
# Workflow execution quota
WORKFLOW = auto()
UNLIMITED = auto()
@property
def billing_key(self) -> str:
"""
Get the billing key for the feature.
"""
match self:
case QuotaType.TRIGGER:
return "trigger_event"
case QuotaType.WORKFLOW:
return "api_rate_limit"
case _:
raise ValueError(f"Invalid quota type: {self}")
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
"""
Consume quota for the feature.
Args:
tenant_id: The tenant identifier
amount: Amount to consume (default: 1)
Returns:
QuotaCharge with success status and charge_id for refund
Raises:
QuotaExceededError: When quota is insufficient
"""
from configs import dify_config
from services.billing_service import BillingService
from services.errors.app import QuotaExceededError
if not dify_config.BILLING_ENABLED:
logger.debug("Billing disabled, allowing request for %s", tenant_id)
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
if amount <= 0:
raise ValueError("Amount to consume must be greater than 0")
try:
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
if response.get("result") != "success":
logger.warning(
"Failed to consume quota for %s, feature %s details: %s",
tenant_id,
self.value,
response.get("detail"),
)
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
charge_id = response.get("history_id")
logger.debug(
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
amount,
self.value,
tenant_id,
charge_id,
)
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
except QuotaExceededError:
raise
except Exception:
# fail-safe: allow request on billing errors
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
return unlimited()
def check(self, tenant_id: str, amount: int = 1) -> bool:
"""
Check if tenant has sufficient quota without consuming.
Args:
tenant_id: The tenant identifier
amount: Amount to check (default: 1)
Returns:
True if quota is sufficient, False otherwise
"""
from configs import dify_config
if not dify_config.BILLING_ENABLED:
return True
if amount <= 0:
raise ValueError("Amount to check must be greater than 0")
try:
remaining = self.get_remaining(tenant_id)
return remaining >= amount if remaining != -1 else True
except Exception:
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
# fail-safe: allow request on billing errors
return True
def refund(self, charge_id: str) -> None:
"""
Refund quota using charge_id from consume().
This method guarantees no exceptions will be raised.
All errors are logged but silently handled.
Args:
charge_id: The UUID returned from consume()
"""
try:
from configs import dify_config
from services.billing_service import BillingService
if not dify_config.BILLING_ENABLED:
return
if not charge_id:
logger.warning("Cannot refund: charge_id is empty")
return
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
if response.get("result") == "success":
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
else:
logger.warning("Refund failed for charge_id: %s", charge_id)
except Exception:
# Catch ALL exceptions - refund must never fail
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
# Don't raise - refund is best-effort and must be silent
def get_remaining(self, tenant_id: str) -> int:
"""
Get remaining quota for the tenant.
Args:
tenant_id: The tenant identifier
Returns:
Remaining quota amount
"""
from services.billing_service import BillingService
try:
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
if isinstance(usage_info, dict):
return usage_info.get("remaining", 0)
# If it returns a simple number, treat it as remaining
return int(usage_info) if usage_info else 0
except Exception:
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
return -1
def unlimited() -> QuotaCharge:
"""
Return a quota charge for unlimited quota.
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
"""
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)

View File

@ -10,7 +10,6 @@ from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.lock import Lock
from redis.sentinel import Sentinel
from configs import dify_config

View File

@ -45,7 +45,6 @@ class ClickZettaVolumeConfig(BaseModel):
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
then fall back to CLICKZETTA_* environment variables (for vector DB config).
"""
import os
# Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:

View File

@ -3,7 +3,7 @@ import io
import json
from collections.abc import Generator
from google.cloud import storage as google_cloud_storage
from google.cloud import storage as google_cloud_storage # type: ignore
from configs import dify_config
from extensions.storage.base_storage import BaseStorage

View File

@ -116,6 +116,7 @@ app_partial_fields = {
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
}

View File

@ -75,6 +75,7 @@ dataset_detail_fields = {
"document_count": fields.Integer,
"word_count": fields.Integer,
"created_by": fields.String,
"author_name": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,

View File

@ -1,3 +1,4 @@
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel
__all__ = ["BroadcastChannel"]
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]

View File

@ -0,0 +1,205 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub
_logger = logging.getLogger(__name__)
class RedisSubscriptionBase(Subscription):
"""Base class for Redis pub/sub subscriptions with common functionality.
This class provides shared functionality for both regular and sharded
Redis pub/sub subscriptions, reducing code duplication and improving
maintainability.
"""
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
def _start_if_needed(self) -> None:
"""Start the subscription if not already started."""
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError(
f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
)
self._subscribe()
_logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _listen(self) -> None:
"""Main listener loop for processing messages."""
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = self._get_message()
if raw_message is None:
continue
if raw_message.get("type") != self._get_message_type():
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning(
"Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error(
"Received invalid data from %s channel %s, type=%s",
self._get_subscription_type(),
self._topic,
type(payload_bytes),
)
continue
self._enqueue_message(payload_bytes)
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
self._unsubscribe()
pubsub.close()
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
"""Enqueue a message to the internal queue with dropping behavior."""
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
self._get_subscription_type(),
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
"""Iterator for consuming messages from the subscription."""
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
"""Return an iterator over messages from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
"""Receive the next message from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
"""Context manager entry point."""
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
"""Context manager exit point."""
self.close()
return None
def close(self) -> None:
"""Close the subscription and clean up resources."""
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
# message retrieval method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
# Abstract methods to be implemented by subclasses
def _get_subscription_type(self) -> str:
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
raise NotImplementedError
def _subscribe(self) -> None:
"""Subscribe to the Redis topic using the appropriate command."""
raise NotImplementedError
def _unsubscribe(self) -> None:
"""Unsubscribe from the Redis topic using the appropriate command."""
raise NotImplementedError
def _get_message(self) -> dict | None:
"""Get a message from Redis using the appropriate method."""
raise NotImplementedError
def _get_message_type(self) -> str:
"""Return the expected message type (e.g., 'message' or 'smessage')."""
raise NotImplementedError

View File

@ -1,24 +1,15 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis
from redis.client import PubSub
_logger = logging.getLogger(__name__)
from ._subscription import RedisSubscriptionBase
class BroadcastChannel:
"""
Redis Pub/Sub based broadcast channel implementation.
Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).
Provides "at most once" delivery semantics for messages published to channels.
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
Provides "at most once" delivery semantics for messages published to channels
using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
"""
@ -54,147 +45,23 @@ class Topic:
)
class _RedisSubscription(Subscription):
def __init__(
self,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
self._dropped_count = 0
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
class _RedisSubscription(RedisSubscriptionBase):
"""Regular Redis pub/sub subscription implementation."""
def _start_if_needed(self) -> None:
with self._start_lock:
if self._started:
return
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
if self._pubsub is None:
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
def _get_subscription_type(self) -> str:
return "regular"
self._pubsub.subscribe(self._topic)
_logger.debug("Subscribed to channel %s", self._topic)
def _subscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.subscribe(self._topic)
self._listener_thread = threading.Thread(
target=self._listen,
name=f"redis-broadcast-{self._topic}",
daemon=True,
)
self._listener_thread.start()
self._started = True
def _unsubscribe(self) -> None:
assert self._pubsub is not None
self._pubsub.unsubscribe(self._topic)
def _listen(self) -> None:
pubsub = self._pubsub
assert pubsub is not None, "PubSub should not be None while starting listening."
while not self._closed.is_set():
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if raw_message is None:
continue
if raw_message.get("type") != "message":
continue
channel_field = raw_message.get("channel")
if isinstance(channel_field, bytes):
channel_name = channel_field.decode("utf-8")
elif isinstance(channel_field, str):
channel_name = channel_field
else:
channel_name = str(channel_field)
if channel_name != self._topic:
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
continue
payload_bytes: bytes | None = raw_message.get("data")
if not isinstance(payload_bytes, bytes):
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
continue
self._enqueue_message(payload_bytes)
_logger.debug("Listener thread stopped for channel %s", self._topic)
pubsub.unsubscribe(self._topic)
pubsub.close()
_logger.debug("PubSub closed for topic %s", self._topic)
self._pubsub = None
def _enqueue_message(self, payload: bytes) -> None:
while not self._closed.is_set():
try:
self._queue.put_nowait(payload)
return
except queue.Full:
try:
self._queue.get_nowait()
self._dropped_count += 1
_logger.debug(
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
self._topic,
self._dropped_count,
)
except queue.Empty:
continue
return
def _message_iterator(self) -> Generator[bytes, None, None]:
while not self._closed.is_set():
try:
item = self._queue.get(timeout=0.1)
except queue.Empty:
continue
yield item
def __iter__(self) -> Iterator[bytes]:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
if self._closed.is_set():
raise SubscriptionClosedError("The Redis subscription is closed")
self._start_if_needed()
try:
item = self._queue.get(timeout=timeout)
except queue.Empty:
return None
return item
def __enter__(self) -> Self:
self._start_if_needed()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
self.close()
return None
def close(self) -> None:
if self._closed.is_set():
return
self._closed.set()
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
# method should NOT be called concurrently.
#
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
self._listener_thread = None
def _get_message_type(self) -> str:
return "message"

Some files were not shown because too many files have changed in this diff Show More