Merge branch 'main' into feat/memory-orchestration-be

This commit is contained in:
Stream 2025-11-24 17:03:32 +08:00
commit 47c1da05f2
No known key found for this signature in database
GPG Key ID: 033728094B100D70
612 changed files with 29362 additions and 6806 deletions

View File

@ -29,7 +29,7 @@ trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation # Matches multiple files with brace expansion notation
# Set default charset # Set default charset
[*.{js,tsx}] [*.{js,jsx,ts,tsx,mjs}]
indent_style = space indent_style = space
indent_size = 2 indent_size = 2

View File

@ -62,7 +62,7 @@ jobs:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
services: | services: |
db db_postgres
redis redis
sandbox sandbox
ssrf_proxy ssrf_proxy

View File

@ -28,6 +28,11 @@ jobs:
# Format code # Format code
uv run ruff format .. uv run ruff format ..
- name: count migration progress
run: |
cd api
./cnt_base.sh
- name: ast-grep - name: ast-grep
run: | run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all

View File

@ -8,7 +8,7 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
db-migration-test: db-migration-test-postgres:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -45,7 +45,7 @@ jobs:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
services: | services: |
db db_postgres
redis redis
- name: Prepare configs - name: Prepare configs
@ -57,3 +57,60 @@ jobs:
env: env:
DEBUG: true DEBUG: true
run: uv run --directory api flask upgrade-db run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_mysql
redis
- name: Prepare configs for MySQL
run: |
cd api
cp .env.example .env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
- name: Run DB Migration
env:
DEBUG: true
run: uv run --directory api flask upgrade-db

View File

@ -51,13 +51,13 @@ jobs:
- name: Expose Service Ports - name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Store (TiDB) # - name: Set up Vector Store (TiDB)
uses: hoverkraft-tech/compose-action@v2.0.2 # uses: hoverkraft-tech/compose-action@v2.0.2
with: # with:
compose-file: docker/tidb/docker-compose.yaml # compose-file: docker/tidb/docker-compose.yaml
services: | # services: |
tidb # tidb
tiflash # tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase) - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2 uses: hoverkraft-tech/compose-action@v2.0.2
@ -83,8 +83,8 @@ jobs:
ls -lah . ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB) # - name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py # run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

2
.gitignore vendored
View File

@ -186,6 +186,8 @@ docker/volumes/couchbase/*
docker/volumes/oceanbase/* docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/* docker/volumes/plugin_daemon/*
docker/volumes/matrixone/* docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf

View File

@ -37,7 +37,7 @@
"-c", "-c",
"1", "1",
"-Q", "-Q",
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline", "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"--loglevel", "--loglevel",
"INFO" "INFO"
], ],

View File

@ -70,6 +70,11 @@ type-check:
@uv run --directory api --dev basedpyright @uv run --directory api --dev basedpyright
@echo "✅ Type check complete" @echo "✅ Type check complete"
test:
@echo "🧪 Running backend unit tests..."
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
@echo "✅ Tests complete"
# Build Docker images # Build Docker images
build-web: build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -119,6 +124,7 @@ help:
@echo " make check - Check code with ruff" @echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff" @echo " make lint - Format and fix code with ruff"
@echo " make type-check - Run type checking with basedpyright" @echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo "" @echo ""
@echo "Docker Build Targets:" @echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image" @echo " make build-web - Build web Docker image"
@ -128,4 +134,4 @@ help:
@echo " make build-push-all - Build and push all Docker images" @echo " make build-push-all - Build and push all Docker images"
# Phony targets # Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check .PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test

View File

@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration # celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis CELERY_BACKEND=redis
# PostgreSQL database configuration
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres DB_USERNAME=postgres
DB_PASSWORD=difyai123456 DB_PASSWORD=difyai123456
DB_HOST=localhost DB_HOST=localhost
DB_PORT=5432 DB_PORT=5432
DB_DATABASE=dify DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30 SQLALCHEMY_POOL_TIMEOUT=30
@ -159,12 +162,11 @@ SUPABASE_URL=your-server-url
# CORS configuration # CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains. # When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). Leading dots are optional.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
COOKIE_DOMAIN= COOKIE_DOMAIN=
# Vector database configuration # Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. # Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database # Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -175,6 +177,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100 WEAVIATE_BATCH_SIZE=100
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
OCEANBASE_FULLTEXT_PARSER=ik
SEEKDB_MEMORY_LIMIT=2G
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode # Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333 QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456 QDRANT_API_KEY=difyai123456
@ -340,15 +353,6 @@ LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1 LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration # AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1 ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306 ALIBABACLOUD_MYSQL_PORT=3306

View File

@ -15,8 +15,8 @@
```bash ```bash
cd ../docker cd ../docker
cp middleware.env.example middleware.env cp middleware.env.example middleware.env
# change the profile to other vector database if you are not using weaviate # change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
cd ../api cd ../api
``` ```
@ -26,6 +26,10 @@
cp .env.example .env cp .env.example .env
``` ```
> [!IMPORTANT]
>
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the sites top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
1. Generate a `SECRET_KEY` in the `.env` file. 1. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux bash for Linux
@ -80,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
``` ```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp:
""" """
dify_app = DifyApp(__name__) dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump()) dify_app.config.from_mapping(dify_config.model_dump())
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
# add before request hook # add before request hook
@dify_app.before_request @dify_app.before_request

7
api/cnt_base.sh Executable file
View File

@ -0,0 +1,7 @@
#!/bin/bash
set -euxo pipefail
for pattern in "Base" "TypeBase"; do
printf "%s " "$pattern"
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
done

View File

@ -77,10 +77,6 @@ class AppExecutionConfig(BaseSettings):
description="Maximum number of concurrent active requests per app (0 for unlimited)", description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0, default=0,
) )
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
description="Maximum number of requests per app per day",
default=5000,
)
class CodeExecutionSandboxConfig(BaseSettings): class CodeExecutionSandboxConfig(BaseSettings):
@ -1086,7 +1082,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
) )
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field( TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds", description="Proactive credential refresh threshold in seconds",
default=180, default=60 * 60,
) )
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field( TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds", description="Proactive subscription refresh threshold in seconds",

View File

@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings): class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
DB_HOST: str = Field( DB_HOST: str = Field(
description="Hostname or IP address of the database server.", description="Hostname or IP address of the database server.",
default="localhost", default="localhost",
@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
default="", default="",
) )
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( @computed_field # type: ignore[prop-decorator]
description="Database URI scheme for SQLAlchemy connection.", @property
default="postgresql", def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
) return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
@computed_field # type: ignore[prop-decorator] @computed_field # type: ignore[prop-decorator]
@property @property
@ -204,14 +210,14 @@ class DatabaseConfig(BaseSettings):
# Parse DB_EXTRAS for 'options' # Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "") options = db_extras_dict.get("options", "")
# Always include timezone connect_args = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC" timezone_opt = "-c timezone=UTC"
if options: if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}" merged_options = f"{options} {timezone_opt}"
else: else:
merged_options = timezone_opt merged_options = timezone_opt
connect_args = {"options": merged_options} connect_args = {"options": merged_options}
return { return {

View File

@ -12,7 +12,7 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
@ -38,10 +38,10 @@ def admin_required(view: Callable[P, R]):
@console_ns.route("/admin/insert-explore-apps") @console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@api.doc("insert_explore_app") @console_ns.doc("insert_explore_app")
@api.doc(description="Insert or update an app in the explore list") @console_ns.doc(description="Insert or update an app in the explore list")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"InsertExploreAppRequest", "InsertExploreAppRequest",
{ {
"app_id": fields.String(required=True, description="Application ID"), "app_id": fields.String(required=True, description="Application ID"),
@ -55,9 +55,9 @@ class InsertExploreAppListApi(Resource):
}, },
) )
) )
@api.response(200, "App updated successfully") @console_ns.response(200, "App updated successfully")
@api.response(201, "App inserted successfully") @console_ns.response(201, "App inserted successfully")
@api.response(404, "App not found") @console_ns.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
@ -131,10 +131,10 @@ class InsertExploreAppListApi(Resource):
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>") @console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource): class InsertExploreAppApi(Resource):
@api.doc("delete_explore_app") @console_ns.doc("delete_explore_app")
@api.doc(description="Remove an app from the explore list") @console_ns.doc(description="Remove an app from the explore list")
@api.doc(params={"app_id": "Application ID to remove"}) @console_ns.doc(params={"app_id": "Application ID to remove"})
@api.response(204, "App removed successfully") @console_ns.response(204, "App removed successfully")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):

View File

@ -11,7 +11,7 @@ from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import api, console_ns from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = { api_key_fields = {
@ -104,14 +104,11 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None resource_model: type | None = None
resource_id_field: str | None = None resource_id_field: str | None = None
def delete(self, resource_id, api_key_id): def delete(self, resource_id: str, api_key_id: str):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model) _get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -136,20 +133,20 @@ class BaseApiKeyResource(Resource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys") @console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource): class AppApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_app_api_keys") @console_ns.doc("get_app_api_keys")
@api.doc(description="Get all API keys for an app") @console_ns.doc(description="Get all API keys for an app")
@api.doc(params={"resource_id": "App ID"}) @console_ns.doc(params={"resource_id": "App ID"})
@api.response(200, "Success", api_key_list) @console_ns.response(200, "Success", api_key_list)
def get(self, resource_id): def get(self, resource_id): # type: ignore
"""Get all API keys for an app""" """Get all API keys for an app"""
return super().get(resource_id) return super().get(resource_id)
@api.doc("create_app_api_key") @console_ns.doc("create_app_api_key")
@api.doc(description="Create a new API key for an app") @console_ns.doc(description="Create a new API key for an app")
@api.doc(params={"resource_id": "App ID"}) @console_ns.doc(params={"resource_id": "App ID"})
@api.response(201, "API key created successfully", api_key_fields) @console_ns.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded") @console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): def post(self, resource_id): # type: ignore
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
@ -161,10 +158,10 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>") @console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class AppApiKeyResource(BaseApiKeyResource): class AppApiKeyResource(BaseApiKeyResource):
@api.doc("delete_app_api_key") @console_ns.doc("delete_app_api_key")
@api.doc(description="Delete an API key for an app") @console_ns.doc(description="Delete an API key for an app")
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) @console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully") @console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
@ -176,20 +173,20 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys") @console_ns.route("/datasets/<uuid:resource_id>/api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource): class DatasetApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_dataset_api_keys") @console_ns.doc("get_dataset_api_keys")
@api.doc(description="Get all API keys for a dataset") @console_ns.doc(description="Get all API keys for a dataset")
@api.doc(params={"resource_id": "Dataset ID"}) @console_ns.doc(params={"resource_id": "Dataset ID"})
@api.response(200, "Success", api_key_list) @console_ns.response(200, "Success", api_key_list)
def get(self, resource_id): def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset""" """Get all API keys for a dataset"""
return super().get(resource_id) return super().get(resource_id)
@api.doc("create_dataset_api_key") @console_ns.doc("create_dataset_api_key")
@api.doc(description="Create a new API key for a dataset") @console_ns.doc(description="Create a new API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID"}) @console_ns.doc(params={"resource_id": "Dataset ID"})
@api.response(201, "API key created successfully", api_key_fields) @console_ns.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded") @console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
@ -201,10 +198,10 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>") @console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class DatasetApiKeyResource(BaseApiKeyResource): class DatasetApiKeyResource(BaseApiKeyResource):
@api.doc("delete_dataset_api_key") @console_ns.doc("delete_dataset_api_key")
@api.doc(description="Delete an API key for a dataset") @console_ns.doc(description="Delete an API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) @console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully") @console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService from services.advanced_prompt_template_service import AdvancedPromptTemplateService
@ -16,13 +16,13 @@ parser = (
@console_ns.route("/app/prompt-templates") @console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@api.doc("get_advanced_prompt_templates") @console_ns.doc("get_advanced_prompt_templates")
@api.doc(description="Get advanced prompt templates based on app mode and model configuration") @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
) )
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value from libs.helper import uuid_value
@ -17,12 +17,14 @@ parser = (
@console_ns.route("/apps/<uuid:app_id>/agent/logs") @console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource): class AgentLogApi(Resource):
@api.doc("get_agent_logs") @console_ns.doc("get_agent_logs")
@api.doc(description="Get agent execution logs for an application") @console_ns.doc(description="Get agent execution logs for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))) @console_ns.response(
@api.response(400, "Invalid request parameters") 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
)
@console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
@ -23,11 +23,11 @@ from services.annotation_service import AppAnnotationService
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@api.doc("annotation_reply_action") @console_ns.doc("annotation_reply_action")
@api.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(description="Enable or disable annotation reply for an app")
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AnnotationReplyActionRequest", "AnnotationReplyActionRequest",
{ {
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
@ -36,8 +36,8 @@ class AnnotationReplyActionApi(Resource):
}, },
) )
) )
@api.response(200, "Action completed successfully") @console_ns.response(200, "Action completed successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -61,11 +61,11 @@ class AnnotationReplyActionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-setting") @console_ns.route("/apps/<uuid:app_id>/annotation-setting")
class AppAnnotationSettingDetailApi(Resource): class AppAnnotationSettingDetailApi(Resource):
@api.doc("get_annotation_setting") @console_ns.doc("get_annotation_setting")
@api.doc(description="Get annotation settings for an app") @console_ns.doc(description="Get annotation settings for an app")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotation settings retrieved successfully") @console_ns.response(200, "Annotation settings retrieved successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -78,11 +78,11 @@ class AppAnnotationSettingDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>") @console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
class AppAnnotationSettingUpdateApi(Resource): class AppAnnotationSettingUpdateApi(Resource):
@api.doc("update_annotation_setting") @console_ns.doc("update_annotation_setting")
@api.doc(description="Update annotation settings for an app") @console_ns.doc(description="Update annotation settings for an app")
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AnnotationSettingUpdateRequest", "AnnotationSettingUpdateRequest",
{ {
"score_threshold": fields.Float(required=True, description="Score threshold"), "score_threshold": fields.Float(required=True, description="Score threshold"),
@ -91,8 +91,8 @@ class AppAnnotationSettingUpdateApi(Resource):
}, },
) )
) )
@api.response(200, "Settings updated successfully") @console_ns.response(200, "Settings updated successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -110,11 +110,11 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource): class AnnotationReplyActionStatusApi(Resource):
@api.doc("get_annotation_reply_action_status") @console_ns.doc("get_annotation_reply_action_status")
@api.doc(description="Get status of annotation reply action job") @console_ns.doc(description="Get status of annotation reply action job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"}) @console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
@api.response(200, "Job status retrieved successfully") @console_ns.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,17 +138,17 @@ class AnnotationReplyActionStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations") @console_ns.route("/apps/<uuid:app_id>/annotations")
class AnnotationApi(Resource): class AnnotationApi(Resource):
@api.doc("list_annotations") @console_ns.doc("list_annotations")
@api.doc(description="Get annotations for an app with pagination") @console_ns.doc(description="Get annotations for an app with pagination")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number") .add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size") .add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword") .add_argument("keyword", type=str, location="args", default="", help="Search keyword")
) )
@api.response(200, "Annotations retrieved successfully") @console_ns.response(200, "Annotations retrieved successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -169,11 +169,11 @@ class AnnotationApi(Resource):
} }
return response, 200 return response, 200
@api.doc("create_annotation") @console_ns.doc("create_annotation")
@api.doc(description="Create a new annotation for an app") @console_ns.doc(description="Create a new annotation for an app")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CreateAnnotationRequest", "CreateAnnotationRequest",
{ {
"message_id": fields.String(description="Message ID (optional)"), "message_id": fields.String(description="Message ID (optional)"),
@ -184,8 +184,8 @@ class AnnotationApi(Resource):
}, },
) )
) )
@api.response(201, "Annotation created successfully", annotation_fields) @console_ns.response(201, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -235,11 +235,11 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export") @console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource): class AnnotationExportApi(Resource):
@api.doc("export_annotations") @console_ns.doc("export_annotations")
@api.doc(description="Export all annotations for an app") @console_ns.doc(description="Export all annotations for an app")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields))) @console_ns.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -260,13 +260,13 @@ parser = (
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@api.doc("update_delete_annotation") @console_ns.doc("update_delete_annotation")
@api.doc(description="Update or delete an annotation") @console_ns.doc(description="Update or delete an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.response(200, "Annotation updated successfully", annotation_fields) @console_ns.response(200, "Annotation updated successfully", annotation_fields)
@api.response(204, "Annotation deleted successfully") @console_ns.response(204, "Annotation deleted successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.expect(parser) @console_ns.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -293,12 +293,12 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import") @console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource): class AnnotationBatchImportApi(Resource):
@api.doc("batch_import_annotations") @console_ns.doc("batch_import_annotations")
@api.doc(description="Batch import annotations from CSV file") @console_ns.doc(description="Batch import annotations from CSV file")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Batch import started successfully") @console_ns.response(200, "Batch import started successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(400, "No file uploaded or too many files") @console_ns.response(400, "No file uploaded or too many files")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -323,11 +323,11 @@ class AnnotationBatchImportApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
class AnnotationBatchImportStatusApi(Resource): class AnnotationBatchImportStatusApi(Resource):
@api.doc("get_batch_import_status") @console_ns.doc("get_batch_import_status")
@api.doc(description="Get status of batch import job") @console_ns.doc(description="Get status of batch import job")
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"}) @console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
@api.response(200, "Job status retrieved successfully") @console_ns.response(200, "Job status retrieved successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -350,18 +350,18 @@ class AnnotationBatchImportStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
class AnnotationHitHistoryListApi(Resource): class AnnotationHitHistoryListApi(Resource):
@api.doc("list_annotation_hit_histories") @console_ns.doc("list_annotation_hit_histories")
@api.doc(description="Get hit histories for an annotation") @console_ns.doc(description="Get hit histories for an annotation")
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number") .add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size") .add_argument("limit", type=int, location="args", default=20, help="Page size")
) )
@api.response( @console_ns.response(
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields)) 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
) )
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,15 +3,16 @@ import uuid
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, abort
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required, edit_permission_required,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
@ -31,10 +32,10 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
@console_ns.route("/apps") @console_ns.route("/apps")
class AppListApi(Resource): class AppListApi(Resource):
@api.doc("list_apps") @console_ns.doc("list_apps")
@api.doc(description="Get list of applications with pagination and filtering") @console_ns.doc(description="Get list of applications with pagination and filtering")
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1) .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20) .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument( .add_argument(
@ -49,7 +50,7 @@ class AppListApi(Resource):
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs") .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator") .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
) )
@api.response(200, "Success", app_pagination_fields) @console_ns.response(200, "Success", app_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,10 +139,10 @@ class AppListApi(Resource):
return marshal(app_pagination, app_pagination_fields), 200 return marshal(app_pagination, app_pagination_fields), 200
@api.doc("create_app") @console_ns.doc("create_app")
@api.doc(description="Create a new application") @console_ns.doc(description="Create a new application")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CreateAppRequest", "CreateAppRequest",
{ {
"name": fields.String(required=True, description="App name"), "name": fields.String(required=True, description="App name"),
@ -153,9 +154,9 @@ class AppListApi(Resource):
}, },
) )
) )
@api.response(201, "App created successfully", app_detail_fields) @console_ns.response(201, "App created successfully", app_detail_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -187,10 +188,10 @@ class AppListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>") @console_ns.route("/apps/<uuid:app_id>")
class AppApi(Resource): class AppApi(Resource):
@api.doc("get_app_detail") @console_ns.doc("get_app_detail")
@api.doc(description="Get application details") @console_ns.doc(description="Get application details")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Success", app_detail_fields_with_site) @console_ns.response(200, "Success", app_detail_fields_with_site)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -209,11 +210,11 @@ class AppApi(Resource):
return app_model return app_model
@api.doc("update_app") @console_ns.doc("update_app")
@api.doc(description="Update application details") @console_ns.doc(description="Update application details")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateAppRequest", "UpdateAppRequest",
{ {
"name": fields.String(required=True, description="App name"), "name": fields.String(required=True, description="App name"),
@ -226,9 +227,9 @@ class AppApi(Resource):
}, },
) )
) )
@api.response(200, "App updated successfully", app_detail_fields_with_site) @console_ns.response(200, "App updated successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -250,10 +251,8 @@ class AppApi(Resource):
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
# Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType
args_dict: AppServiceType.ArgsDict = { args_dict: AppService.ArgsDict = {
"name": args["name"], "name": args["name"],
"description": args.get("description", ""), "description": args.get("description", ""),
"icon_type": args.get("icon_type", ""), "icon_type": args.get("icon_type", ""),
@ -266,11 +265,11 @@ class AppApi(Resource):
return app_model return app_model
@api.doc("delete_app") @console_ns.doc("delete_app")
@api.doc(description="Delete application") @console_ns.doc(description="Delete application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(204, "App deleted successfully") @console_ns.response(204, "App deleted successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -286,11 +285,11 @@ class AppApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/copy") @console_ns.route("/apps/<uuid:app_id>/copy")
class AppCopyApi(Resource): class AppCopyApi(Resource):
@api.doc("copy_app") @console_ns.doc("copy_app")
@api.doc(description="Create a copy of an existing application") @console_ns.doc(description="Create a copy of an existing application")
@api.doc(params={"app_id": "Application ID to copy"}) @console_ns.doc(params={"app_id": "Application ID to copy"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CopyAppRequest", "CopyAppRequest",
{ {
"name": fields.String(description="Name for the copied app"), "name": fields.String(description="Name for the copied app"),
@ -301,8 +300,8 @@ class AppCopyApi(Resource):
}, },
) )
) )
@api.response(201, "App copied successfully", app_detail_fields_with_site) @console_ns.response(201, "App copied successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -347,20 +346,20 @@ class AppCopyApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/export") @console_ns.route("/apps/<uuid:app_id>/export")
class AppExportApi(Resource): class AppExportApi(Resource):
@api.doc("export_app") @console_ns.doc("export_app")
@api.doc(description="Export application configuration as DSL") @console_ns.doc(description="Export application configuration as DSL")
@api.doc(params={"app_id": "Application ID to export"}) @console_ns.doc(params={"app_id": "Application ID to export"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export") .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export") .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
) )
@api.response( @console_ns.response(
200, 200,
"App exported successfully", "App exported successfully",
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
) )
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -388,11 +387,11 @@ parser = reqparse.RequestParser().add_argument("name", type=str, required=True,
@console_ns.route("/apps/<uuid:app_id>/name") @console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource): class AppNameApi(Resource):
@api.doc("check_app_name") @console_ns.doc("check_app_name")
@api.doc(description="Check if app name is available") @console_ns.doc(description="Check if app name is available")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response(200, "Name availability checked") @console_ns.response(200, "Name availability checked")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -410,11 +409,11 @@ class AppNameApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/icon") @console_ns.route("/apps/<uuid:app_id>/icon")
class AppIconApi(Resource): class AppIconApi(Resource):
@api.doc("update_app_icon") @console_ns.doc("update_app_icon")
@api.doc(description="Update application icon") @console_ns.doc(description="Update application icon")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AppIconRequest", "AppIconRequest",
{ {
"icon": fields.String(required=True, description="Icon data"), "icon": fields.String(required=True, description="Icon data"),
@ -423,8 +422,8 @@ class AppIconApi(Resource):
}, },
) )
) )
@api.response(200, "Icon updated successfully") @console_ns.response(200, "Icon updated successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -447,16 +446,16 @@ class AppIconApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/site-enable") @console_ns.route("/apps/<uuid:app_id>/site-enable")
class AppSiteStatus(Resource): class AppSiteStatus(Resource):
@api.doc("update_app_site_status") @console_ns.doc("update_app_site_status")
@api.doc(description="Enable or disable app site") @console_ns.doc(description="Enable or disable app site")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")} "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
) )
) )
@api.response(200, "Site status updated successfully", app_detail_fields) @console_ns.response(200, "Site status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -475,27 +474,23 @@ class AppSiteStatus(Resource):
@console_ns.route("/apps/<uuid:app_id>/api-enable") @console_ns.route("/apps/<uuid:app_id>/api-enable")
class AppApiStatus(Resource): class AppApiStatus(Resource):
@api.doc("update_app_api_status") @console_ns.doc("update_app_api_status")
@api.doc(description="Enable or disable app API") @console_ns.doc(description="Enable or disable app API")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")} "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
) )
) )
@api.response(200, "API status updated successfully", app_detail_fields) @console_ns.response(200, "API status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json") parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -507,10 +502,10 @@ class AppApiStatus(Resource):
@console_ns.route("/apps/<uuid:app_id>/trace") @console_ns.route("/apps/<uuid:app_id>/trace")
class AppTraceApi(Resource): class AppTraceApi(Resource):
@api.doc("get_app_trace") @console_ns.doc("get_app_trace")
@api.doc(description="Get app tracing configuration") @console_ns.doc(description="Get app tracing configuration")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Trace configuration retrieved successfully") @console_ns.response(200, "Trace configuration retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -520,11 +515,11 @@ class AppTraceApi(Resource):
return app_trace_config return app_trace_config
@api.doc("update_app_trace") @console_ns.doc("update_app_trace")
@api.doc(description="Update app tracing configuration") @console_ns.doc(description="Update app tracing configuration")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AppTraceRequest", "AppTraceRequest",
{ {
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"), "enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
@ -532,8 +527,8 @@ class AppTraceApi(Resource):
}, },
) )
) )
@api.response(200, "Trace configuration updated successfully") @console_ns.response(200, "Trace configuration updated successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,7 +1,6 @@
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -35,7 +34,7 @@ parser = (
@console_ns.route("/apps/imports") @console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@api.expect(parser) @console_ns.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@ -36,16 +36,16 @@ logger = logging.getLogger(__name__)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text") @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
class ChatMessageAudioApi(Resource): class ChatMessageAudioApi(Resource):
@api.doc("chat_message_audio_transcript") @console_ns.doc("chat_message_audio_transcript")
@api.doc(description="Transcript audio to text for chat messages") @console_ns.doc(description="Transcript audio to text for chat messages")
@api.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@api.response( @console_ns.response(
200, 200,
"Audio transcription successful", "Audio transcription successful",
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
) )
@api.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@api.response(413, "Audio file too large") @console_ns.response(413, "Audio file too large")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -89,11 +89,11 @@ class ChatMessageAudioApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio") @console_ns.route("/apps/<uuid:app_id>/text-to-audio")
class ChatMessageTextApi(Resource): class ChatMessageTextApi(Resource):
@api.doc("chat_message_text_to_speech") @console_ns.doc("chat_message_text_to_speech")
@api.doc(description="Convert text to speech for chat messages") @console_ns.doc(description="Convert text to speech for chat messages")
@api.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"TextToSpeechRequest", "TextToSpeechRequest",
{ {
"message_id": fields.String(description="Message ID"), "message_id": fields.String(description="Message ID"),
@ -103,8 +103,8 @@ class ChatMessageTextApi(Resource):
}, },
) )
) )
@api.response(200, "Text to speech conversion successful") @console_ns.response(200, "Text to speech conversion successful")
@api.response(400, "Bad request - Invalid parameters") @console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -156,12 +156,16 @@ class ChatMessageTextApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices") @console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
class TextModesApi(Resource): class TextModesApi(Resource):
@api.doc("get_text_to_speech_voices") @console_ns.doc("get_text_to_speech_voices")
@api.doc(description="Get available TTS voices for a specific language") @console_ns.doc(description="Get available TTS voices for a specific language")
@api.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code")) @console_ns.expect(
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))) console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
@api.response(400, "Invalid language parameter") )
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@console_ns.response(400, "Invalid language parameter")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
@ -40,11 +40,11 @@ logger = logging.getLogger(__name__)
# define completion message api for user # define completion message api for user
@console_ns.route("/apps/<uuid:app_id>/completion-messages") @console_ns.route("/apps/<uuid:app_id>/completion-messages")
class CompletionMessageApi(Resource): class CompletionMessageApi(Resource):
@api.doc("create_completion_message") @console_ns.doc("create_completion_message")
@api.doc(description="Generate completion message for debugging") @console_ns.doc(description="Generate completion message for debugging")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CompletionMessageRequest", "CompletionMessageRequest",
{ {
"inputs": fields.Raw(required=True, description="Input variables"), "inputs": fields.Raw(required=True, description="Input variables"),
@ -56,9 +56,9 @@ class CompletionMessageApi(Resource):
}, },
) )
) )
@api.response(200, "Completion generated successfully") @console_ns.response(200, "Completion generated successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(404, "App not found") @console_ns.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -110,10 +110,10 @@ class CompletionMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
class CompletionMessageStopApi(Resource): class CompletionMessageStopApi(Resource):
@api.doc("stop_completion_message") @console_ns.doc("stop_completion_message")
@api.doc(description="Stop a running completion message generation") @console_ns.doc(description="Stop a running completion message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully") @console_ns.response(200, "Task stopped successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -128,11 +128,11 @@ class CompletionMessageStopApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages") @console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageApi(Resource): class ChatMessageApi(Resource):
@api.doc("create_chat_message") @console_ns.doc("create_chat_message")
@api.doc(description="Generate chat message for debugging") @console_ns.doc(description="Generate chat message for debugging")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ChatMessageRequest", "ChatMessageRequest",
{ {
"inputs": fields.Raw(required=True, description="Input variables"), "inputs": fields.Raw(required=True, description="Input variables"),
@ -146,9 +146,9 @@ class ChatMessageApi(Resource):
}, },
) )
) )
@api.response(200, "Chat message generated successfully") @console_ns.response(200, "Chat message generated successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(404, "App or conversation not found") @console_ns.response(404, "App or conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -209,10 +209,10 @@ class ChatMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
class ChatMessageStopApi(Resource): class ChatMessageStopApi(Resource):
@api.doc("stop_chat_message") @console_ns.doc("stop_chat_message")
@api.doc(description="Stop a running chat message generation") @console_ns.doc(description="Stop a running chat message generation")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@api.response(200, "Task stopped successfully") @console_ns.response(200, "Task stopped successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -6,7 +6,7 @@ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -28,11 +28,11 @@ from services.errors.conversation import ConversationNotExistsError
@console_ns.route("/apps/<uuid:app_id>/completion-conversations") @console_ns.route("/apps/<uuid:app_id>/completion-conversations")
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
@api.doc("list_completion_conversations") @console_ns.doc("list_completion_conversations")
@api.doc(description="Get completion conversations with pagination and filtering") @console_ns.doc(description="Get completion conversations with pagination and filtering")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword") .add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
@ -47,8 +47,8 @@ class CompletionConversationApi(Resource):
.add_argument("page", type=int, location="args", default=1, help="Page number") .add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)") .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
) )
@api.response(200, "Success", conversation_pagination_fields) @console_ns.response(200, "Success", conversation_pagination_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -122,12 +122,12 @@ class CompletionConversationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") @console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
class CompletionConversationDetailApi(Resource): class CompletionConversationDetailApi(Resource):
@api.doc("get_completion_conversation") @console_ns.doc("get_completion_conversation")
@api.doc(description="Get completion conversation details with messages") @console_ns.doc(description="Get completion conversation details with messages")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_message_detail_fields) @console_ns.response(200, "Success", conversation_message_detail_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -139,12 +139,12 @@ class CompletionConversationDetailApi(Resource):
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@api.doc("delete_completion_conversation") @console_ns.doc("delete_completion_conversation")
@api.doc(description="Delete a completion conversation") @console_ns.doc(description="Delete a completion conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully") @console_ns.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -164,11 +164,11 @@ class CompletionConversationDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-conversations") @console_ns.route("/apps/<uuid:app_id>/chat-conversations")
class ChatConversationApi(Resource): class ChatConversationApi(Resource):
@api.doc("list_chat_conversations") @console_ns.doc("list_chat_conversations")
@api.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword") .add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
@ -192,8 +192,8 @@ class ChatConversationApi(Resource):
help="Sort field and direction", help="Sort field and direction",
) )
) )
@api.response(200, "Success", conversation_with_summary_pagination_fields) @console_ns.response(200, "Success", conversation_with_summary_pagination_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -322,12 +322,12 @@ class ChatConversationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") @console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
class ChatConversationDetailApi(Resource): class ChatConversationDetailApi(Resource):
@api.doc("get_chat_conversation") @console_ns.doc("get_chat_conversation")
@api.doc(description="Get chat conversation details") @console_ns.doc(description="Get chat conversation details")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(200, "Success", conversation_detail_fields) @console_ns.response(200, "Success", conversation_detail_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -339,12 +339,12 @@ class ChatConversationDetailApi(Resource):
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@api.doc("delete_chat_conversation") @console_ns.doc("delete_chat_conversation")
@api.doc(description="Delete a chat conversation") @console_ns.doc(description="Delete a chat conversation")
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@api.response(204, "Conversation deleted successfully") @console_ns.response(204, "Conversation deleted successfully")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])

View File

@ -2,7 +2,7 @@ from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
@ -14,15 +14,15 @@ from models.model import AppMode
@console_ns.route("/apps/<uuid:app_id>/conversation-variables") @console_ns.route("/apps/<uuid:app_id>/conversation-variables")
class ConversationVariablesApi(Resource): class ConversationVariablesApi(Resource):
@api.doc("get_conversation_variables") @console_ns.doc("get_conversation_variables")
@api.doc(description="Get conversation variables for an application") @console_ns.doc(description="Get conversation variables for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser().add_argument( console_ns.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables" "conversation_id", type=str, location="args", help="Conversation ID to filter variables"
) )
) )
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields) @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -17,10 +17,10 @@ from services.workflow_service import WorkflowService
@console_ns.route("/rule-generate") @console_ns.route("/rule-generate")
class RuleGenerateApi(Resource): class RuleGenerateApi(Resource):
@api.doc("generate_rule_config") @console_ns.doc("generate_rule_config")
@api.doc(description="Generate rule configuration using LLM") @console_ns.doc(description="Generate rule configuration using LLM")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"RuleGenerateRequest", "RuleGenerateRequest",
{ {
"instruction": fields.String(required=True, description="Rule generation instruction"), "instruction": fields.String(required=True, description="Rule generation instruction"),
@ -29,9 +29,9 @@ class RuleGenerateApi(Resource):
}, },
) )
) )
@api.response(200, "Rule configuration generated successfully") @console_ns.response(200, "Rule configuration generated successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -66,10 +66,10 @@ class RuleGenerateApi(Resource):
@console_ns.route("/rule-code-generate") @console_ns.route("/rule-code-generate")
class RuleCodeGenerateApi(Resource): class RuleCodeGenerateApi(Resource):
@api.doc("generate_rule_code") @console_ns.doc("generate_rule_code")
@api.doc(description="Generate code rules using LLM") @console_ns.doc(description="Generate code rules using LLM")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"RuleCodeGenerateRequest", "RuleCodeGenerateRequest",
{ {
"instruction": fields.String(required=True, description="Code generation instruction"), "instruction": fields.String(required=True, description="Code generation instruction"),
@ -81,9 +81,9 @@ class RuleCodeGenerateApi(Resource):
}, },
) )
) )
@api.response(200, "Code rules generated successfully") @console_ns.response(200, "Code rules generated successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -119,10 +119,10 @@ class RuleCodeGenerateApi(Resource):
@console_ns.route("/rule-structured-output-generate") @console_ns.route("/rule-structured-output-generate")
class RuleStructuredOutputGenerateApi(Resource): class RuleStructuredOutputGenerateApi(Resource):
@api.doc("generate_structured_output") @console_ns.doc("generate_structured_output")
@api.doc(description="Generate structured output rules using LLM") @console_ns.doc(description="Generate structured output rules using LLM")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"StructuredOutputGenerateRequest", "StructuredOutputGenerateRequest",
{ {
"instruction": fields.String(required=True, description="Structured output generation instruction"), "instruction": fields.String(required=True, description="Structured output generation instruction"),
@ -130,9 +130,9 @@ class RuleStructuredOutputGenerateApi(Resource):
}, },
) )
) )
@api.response(200, "Structured output generated successfully") @console_ns.response(200, "Structured output generated successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -165,10 +165,10 @@ class RuleStructuredOutputGenerateApi(Resource):
@console_ns.route("/instruction-generate") @console_ns.route("/instruction-generate")
class InstructionGenerateApi(Resource): class InstructionGenerateApi(Resource):
@api.doc("generate_instruction") @console_ns.doc("generate_instruction")
@api.doc(description="Generate instruction for workflow nodes or general use") @console_ns.doc(description="Generate instruction for workflow nodes or general use")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"InstructionGenerateRequest", "InstructionGenerateRequest",
{ {
"type": fields.String( "type": fields.String(
@ -199,9 +199,9 @@ class InstructionGenerateApi(Resource):
}, },
) )
) )
@api.response(200, "Instruction generated successfully") @console_ns.response(200, "Instruction generated successfully")
@api.response(400, "Invalid request parameters or flow/workflow not found") @console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@api.response(402, "Provider quota exceeded") @console_ns.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -366,10 +366,10 @@ class InstructionGenerateApi(Resource):
@console_ns.route("/instruction-generate/template") @console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource): class InstructionGenerationTemplateApi(Resource):
@api.doc("get_instruction_template") @console_ns.doc("get_instruction_template")
@api.doc(description="Get instruction generation template") @console_ns.doc(description="Get instruction generation template")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"InstructionTemplateRequest", "InstructionTemplateRequest",
{ {
"instruction": fields.String(required=True, description="Template instruction"), "instruction": fields.String(required=True, description="Template instruction"),
@ -377,8 +377,8 @@ class InstructionGenerationTemplateApi(Resource):
}, },
) )
) )
@api.response(200, "Template retrieved successfully") @console_ns.response(200, "Template retrieved successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
@ -20,10 +20,10 @@ class AppMCPServerStatus(StrEnum):
@console_ns.route("/apps/<uuid:app_id>/server") @console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource): class AppMCPServerController(Resource):
@api.doc("get_app_mcp_server") @console_ns.doc("get_app_mcp_server")
@api.doc(description="Get MCP server configuration for an application") @console_ns.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @console_ns.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required @setup_required
@ -33,11 +33,11 @@ class AppMCPServerController(Resource):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server return server
@api.doc("create_app_mcp_server") @console_ns.doc("create_app_mcp_server")
@api.doc(description="Create MCP server configuration for an application") @console_ns.doc(description="Create MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"MCPServerCreateRequest", "MCPServerCreateRequest",
{ {
"description": fields.String(description="Server description"), "description": fields.String(description="Server description"),
@ -45,8 +45,8 @@ class AppMCPServerController(Resource):
}, },
) )
) )
@api.response(201, "MCP server configuration created successfully", app_server_fields) @console_ns.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@login_required @login_required
@ -79,11 +79,11 @@ class AppMCPServerController(Resource):
db.session.commit() db.session.commit()
return server return server
@api.doc("update_app_mcp_server") @console_ns.doc("update_app_mcp_server")
@api.doc(description="Update MCP server configuration for an application") @console_ns.doc(description="Update MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"MCPServerUpdateRequest", "MCPServerUpdateRequest",
{ {
"id": fields.String(required=True, description="Server ID"), "id": fields.String(required=True, description="Server ID"),
@ -93,9 +93,9 @@ class AppMCPServerController(Resource):
}, },
) )
) )
@api.response(200, "MCP server configuration updated successfully", app_server_fields) @console_ns.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Server not found") @console_ns.response(404, "Server not found")
@get_app_model @get_app_model
@login_required @login_required
@setup_required @setup_required
@ -134,12 +134,12 @@ class AppMCPServerController(Resource):
@console_ns.route("/apps/<uuid:server_id>/server/refresh") @console_ns.route("/apps/<uuid:server_id>/server/refresh")
class AppMCPServerRefreshController(Resource): class AppMCPServerRefreshController(Resource):
@api.doc("refresh_app_mcp_server") @console_ns.doc("refresh_app_mcp_server")
@api.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@api.doc(params={"server_id": "Server ID"}) @console_ns.doc(params={"server_id": "Server ID"})
@api.response(200, "MCP server refreshed successfully", app_server_fields) @console_ns.response(200, "MCP server refreshed successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "Server not found") @console_ns.response(404, "Server not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -5,7 +5,7 @@ from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -43,17 +43,17 @@ class ChatMessageListApi(Resource):
"data": fields.List(fields.Nested(message_detail_fields)), "data": fields.List(fields.Nested(message_detail_fields)),
} }
@api.doc("list_chat_messages") @console_ns.doc("list_chat_messages")
@api.doc(description="Get chat messages for a conversation with pagination") @console_ns.doc(description="Get chat messages for a conversation with pagination")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID") .add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination") .add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)") .add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
) )
@api.response(200, "Success", message_infinite_scroll_pagination_fields) @console_ns.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found") @console_ns.response(404, "Conversation not found")
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required @setup_required
@ -132,11 +132,11 @@ class ChatMessageListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/feedbacks") @console_ns.route("/apps/<uuid:app_id>/feedbacks")
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@api.doc("create_message_feedback") @console_ns.doc("create_message_feedback")
@api.doc(description="Create or update message feedback (like/dislike)") @console_ns.doc(description="Create or update message feedback (like/dislike)")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"MessageFeedbackRequest", "MessageFeedbackRequest",
{ {
"message_id": fields.String(required=True, description="Message ID"), "message_id": fields.String(required=True, description="Message ID"),
@ -144,9 +144,9 @@ class MessageFeedbackApi(Resource):
}, },
) )
) )
@api.response(200, "Feedback updated successfully") @console_ns.response(200, "Feedback updated successfully")
@api.response(404, "Message not found") @console_ns.response(404, "Message not found")
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -194,13 +194,13 @@ class MessageFeedbackApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/count") @console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource): class MessageAnnotationCountApi(Resource):
@api.doc("get_annotation_count") @console_ns.doc("get_annotation_count")
@api.doc(description="Get count of message annotations for the app") @console_ns.doc(description="Get count of message annotations for the app")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response( @console_ns.response(
200, 200,
"Annotation count retrieved successfully", "Annotation count retrieved successfully",
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
) )
@get_app_model @get_app_model
@setup_required @setup_required
@ -214,15 +214,17 @@ class MessageAnnotationCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions") @console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(Resource): class MessageSuggestedQuestionApi(Resource):
@api.doc("get_message_suggested_questions") @console_ns.doc("get_message_suggested_questions")
@api.doc(description="Get suggested questions for a message") @console_ns.doc(description="Get suggested questions for a message")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response( @console_ns.response(
200, 200,
"Suggested questions retrieved successfully", "Suggested questions retrieved successfully",
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}), console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
) )
@api.response(404, "Message or conversation not found") @console_ns.response(404, "Message or conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -258,11 +260,11 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>") @console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource): class MessageApi(Resource):
@api.doc("get_message") @console_ns.doc("get_message")
@api.doc(description="Get message details by ID") @console_ns.doc(description="Get message details by ID")
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields) @console_ns.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found") @console_ns.response(404, "Message not found")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required

View File

@ -3,11 +3,10 @@ from typing import cast
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
@ -21,11 +20,11 @@ from services.app_model_config_service import AppModelConfigService
@console_ns.route("/apps/<uuid:app_id>/model-config") @console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource): class ModelConfigResource(Resource):
@api.doc("update_app_model_config") @console_ns.doc("update_app_model_config")
@api.doc(description="Update application model configuration") @console_ns.doc(description="Update application model configuration")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ModelConfigRequest", "ModelConfigRequest",
{ {
"provider": fields.String(description="Model provider"), "provider": fields.String(description="Model provider"),
@ -43,20 +42,17 @@ class ModelConfigResource(Resource):
}, },
) )
) )
@api.response(200, "Model configuration updated successfully") @console_ns.response(200, "Model configuration updated successfully")
@api.response(400, "Invalid configuration") @console_ns.response(400, "Invalid configuration")
@api.response(404, "App not found") @console_ns.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
@ -14,18 +14,18 @@ class TraceAppConfigApi(Resource):
Manage trace app configurations Manage trace app configurations
""" """
@api.doc("get_trace_app_config") @console_ns.doc("get_trace_app_config")
@api.doc(description="Get tracing configuration for an application") @console_ns.doc(description="Get tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser().add_argument( console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name" "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
) )
) )
@api.response( @console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
) )
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -41,11 +41,11 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@api.doc("create_trace_app_config") @console_ns.doc("create_trace_app_config")
@api.doc(description="Create a new tracing configuration for an application") @console_ns.doc(description="Create a new tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"TraceConfigCreateRequest", "TraceConfigCreateRequest",
{ {
"tracing_provider": fields.String(required=True, description="Tracing provider name"), "tracing_provider": fields.String(required=True, description="Tracing provider name"),
@ -53,10 +53,10 @@ class TraceAppConfigApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
) )
@api.response(400, "Invalid request parameters or configuration already exists") @console_ns.response(400, "Invalid request parameters or configuration already exists")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -81,11 +81,11 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@api.doc("update_trace_app_config") @console_ns.doc("update_trace_app_config")
@api.doc(description="Update an existing tracing configuration for an application") @console_ns.doc(description="Update an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"TraceConfigUpdateRequest", "TraceConfigUpdateRequest",
{ {
"tracing_provider": fields.String(required=True, description="Tracing provider name"), "tracing_provider": fields.String(required=True, description="Tracing provider name"),
@ -93,8 +93,8 @@ class TraceAppConfigApi(Resource):
}, },
) )
) )
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@api.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -117,16 +117,16 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@api.doc("delete_trace_app_config") @console_ns.doc("delete_trace_app_config")
@api.doc(description="Delete an existing tracing configuration for an application") @console_ns.doc(description="Delete an existing tracing configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser().add_argument( console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name" "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
) )
) )
@api.response(204, "Tracing configuration deleted successfully") @console_ns.response(204, "Tracing configuration deleted successfully")
@api.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,10 +1,15 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import NotFound
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -43,11 +48,11 @@ def parse_app_site_args():
@console_ns.route("/apps/<uuid:app_id>/site") @console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource): class AppSite(Resource):
@api.doc("update_app_site") @console_ns.doc("update_app_site")
@api.doc(description="Update application site configuration") @console_ns.doc(description="Update application site configuration")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AppSiteRequest", "AppSiteRequest",
{ {
"title": fields.String(description="Site title"), "title": fields.String(description="Site title"),
@ -71,22 +76,18 @@ class AppSite(Resource):
}, },
) )
) )
@api.response(200, "Site configuration updated successfully", app_site_fields) @console_ns.response(200, "Site configuration updated successfully", app_site_fields)
@api.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@api.response(404, "App not found") @console_ns.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound
@ -122,24 +123,20 @@ class AppSite(Resource):
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset") @console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
class AppSiteAccessTokenReset(Resource): class AppSiteAccessTokenReset(Resource):
@api.doc("reset_app_site_access_token") @console_ns.doc("reset_app_site_access_token")
@api.doc(description="Reset access token for application site") @console_ns.doc(description="Reset access token for application site")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Access token reset successfully", app_site_fields) @console_ns.response(200, "Access token reset successfully", app_site_fields)
@api.response(403, "Insufficient permissions (admin/owner required)") @console_ns.response(403, "Insufficient permissions (admin/owner required)")
@api.response(404, "App or site not found") @console_ns.response(404, "App or site not found")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:

View File

@ -4,28 +4,28 @@ import sqlalchemy as sa
from flask import abort, jsonify from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message from models import AppMode
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@api.doc("get_daily_message_statistics") @console_ns.doc("get_daily_message_statistics")
@api.doc(description="Get daily message statistics for an application") @console_ns.doc(description="Get daily message statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)") .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)") .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
) )
@api.response( @console_ns.response(
200, 200,
"Daily message statistics retrieved successfully", "Daily message statistics retrieved successfully",
fields.List(fields.Raw(description="Daily message count data")), fields.List(fields.Raw(description="Daily message count data")),
@ -44,8 +44,9 @@ class DailyMessageStatistic(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("created_at")
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(*) AS message_count COUNT(*) AS message_count
FROM FROM
messages messages
@ -89,11 +90,11 @@ parser = (
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@api.doc("get_daily_conversation_statistics") @console_ns.doc("get_daily_conversation_statistics")
@api.doc(description="Get daily conversation statistics for an application") @console_ns.doc(description="Get daily conversation statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Daily conversation statistics retrieved successfully", "Daily conversation statistics retrieved successfully",
fields.List(fields.Raw(description="Daily conversation count data")), fields.List(fields.Raw(description="Daily conversation count data")),
@ -106,6 +107,17 @@ class DailyConversationStatistic(Resource):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None assert account.timezone is not None
try: try:
@ -113,41 +125,32 @@ class DailyConversationStatistic(Resource):
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if start_datetime_utc: if start_datetime_utc:
stmt = stmt.where(Message.created_at >= start_datetime_utc) sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if end_datetime_utc: if end_datetime_utc:
stmt = stmt.where(Message.created_at < end_datetime_utc) sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
stmt = stmt.group_by("date").order_by("date") sql_query += " GROUP BY date ORDER BY date"
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(stmt, {"tz": account.timezone}) rs = conn.execute(sa.text(sql_query), arg_dict)
for row in rs: for i in rs:
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
return jsonify({"data": response_data}) return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
class DailyTerminalsStatistic(Resource): class DailyTerminalsStatistic(Resource):
@api.doc("get_daily_terminals_statistics") @console_ns.doc("get_daily_terminals_statistics")
@api.doc(description="Get daily terminal/end-user statistics for an application") @console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Daily terminal statistics retrieved successfully", "Daily terminal statistics retrieved successfully",
fields.List(fields.Raw(description="Daily terminal count data")), fields.List(fields.Raw(description="Daily terminal count data")),
@ -161,8 +164,9 @@ class DailyTerminalsStatistic(Resource):
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("created_at")
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM FROM
messages messages
@ -199,11 +203,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs") @console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
class DailyTokenCostStatistic(Resource): class DailyTokenCostStatistic(Resource):
@api.doc("get_daily_token_cost_statistics") @console_ns.doc("get_daily_token_cost_statistics")
@api.doc(description="Get daily token cost statistics for an application") @console_ns.doc(description="Get daily token cost statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Daily token cost statistics retrieved successfully", "Daily token cost statistics retrieved successfully",
fields.List(fields.Raw(description="Daily token cost data")), fields.List(fields.Raw(description="Daily token cost data")),
@ -217,8 +221,9 @@ class DailyTokenCostStatistic(Resource):
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("created_at")
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price SUM(total_price) AS total_price
FROM FROM
@ -258,11 +263,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions") @console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
class AverageSessionInteractionStatistic(Resource): class AverageSessionInteractionStatistic(Resource):
@api.doc("get_average_session_interaction_statistics") @console_ns.doc("get_average_session_interaction_statistics")
@api.doc(description="Get average session interaction statistics for an application") @console_ns.doc(description="Get average session interaction statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Average session interaction statistics retrieved successfully", "Average session interaction statistics retrieved successfully",
fields.List(fields.Raw(description="Average session interaction data")), fields.List(fields.Raw(description="Average session interaction data")),
@ -276,8 +281,9 @@ class AverageSessionInteractionStatistic(Resource):
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("c.created_at")
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(subquery.message_count) AS interactions AVG(subquery.message_count) AS interactions
FROM FROM
( (
@ -333,11 +339,11 @@ ORDER BY
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate") @console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
class UserSatisfactionRateStatistic(Resource): class UserSatisfactionRateStatistic(Resource):
@api.doc("get_user_satisfaction_rate_statistics") @console_ns.doc("get_user_satisfaction_rate_statistics")
@api.doc(description="Get user satisfaction rate statistics for an application") @console_ns.doc(description="Get user satisfaction rate statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"User satisfaction rate statistics retrieved successfully", "User satisfaction rate statistics retrieved successfully",
fields.List(fields.Raw(description="User satisfaction rate data")), fields.List(fields.Raw(description="User satisfaction rate data")),
@ -351,8 +357,9 @@ class UserSatisfactionRateStatistic(Resource):
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("m.created_at")
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(m.id) AS message_count, COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count COUNT(mf.id) AS feedback_count
FROM FROM
@ -398,11 +405,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time") @console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
class AverageResponseTimeStatistic(Resource): class AverageResponseTimeStatistic(Resource):
@api.doc("get_average_response_time_statistics") @console_ns.doc("get_average_response_time_statistics")
@api.doc(description="Get average response time statistics for an application") @console_ns.doc(description="Get average response time statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Average response time statistics retrieved successfully", "Average response time statistics retrieved successfully",
fields.List(fields.Raw(description="Average response time data")), fields.List(fields.Raw(description="Average response time data")),
@ -416,8 +423,9 @@ class AverageResponseTimeStatistic(Resource):
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("created_at")
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
AVG(provider_response_latency) AS latency AVG(provider_response_latency) AS latency
FROM FROM
messages messages
@ -454,11 +462,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second") @console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
class TokensPerSecondStatistic(Resource): class TokensPerSecondStatistic(Resource):
@api.doc("get_tokens_per_second_statistics") @console_ns.doc("get_tokens_per_second_statistics")
@api.doc(description="Get tokens per second statistics for an application") @console_ns.doc(description="Get tokens per second statistics for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Tokens per second statistics retrieved successfully", "Tokens per second statistics retrieved successfully",
fields.List(fields.Raw(description="Tokens per second data")), fields.List(fields.Raw(description="Tokens per second data")),
@ -471,8 +479,9 @@ class TokensPerSecondStatistic(Resource):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT converted_created_at = convert_datetime_to_date("created_at")
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = f"""SELECT
{converted_created_at} AS date,
CASE CASE
WHEN SUM(provider_response_latency) = 0 THEN 0 WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) ELSE (SUM(answer_tokens) / SUM(provider_response_latency))

View File

@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -71,11 +71,11 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
@console_ns.route("/apps/<uuid:app_id>/workflows/draft") @console_ns.route("/apps/<uuid:app_id>/workflows/draft")
class DraftWorkflowApi(Resource): class DraftWorkflowApi(Resource):
@api.doc("get_draft_workflow") @console_ns.doc("get_draft_workflow")
@api.doc(description="Get draft workflow for an application") @console_ns.doc(description="Get draft workflow for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Draft workflow retrieved successfully", workflow_fields) @console_ns.response(200, "Draft workflow retrieved successfully", workflow_fields)
@api.response(404, "Draft workflow not found") @console_ns.response(404, "Draft workflow not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -100,10 +100,10 @@ class DraftWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@api.doc("sync_draft_workflow") @console_ns.doc("sync_draft_workflow")
@api.doc(description="Sync draft workflow configuration") @console_ns.doc(description="Sync draft workflow configuration")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"SyncDraftWorkflowRequest", "SyncDraftWorkflowRequest",
{ {
"graph": fields.Raw(required=True, description="Workflow graph configuration"), "graph": fields.Raw(required=True, description="Workflow graph configuration"),
@ -115,10 +115,10 @@ class DraftWorkflowApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Draft workflow synced successfully", "Draft workflow synced successfully",
api.model( console_ns.model(
"SyncDraftWorkflowResponse", "SyncDraftWorkflowResponse",
{ {
"result": fields.String, "result": fields.String,
@ -127,8 +127,8 @@ class DraftWorkflowApi(Resource):
}, },
), ),
) )
@api.response(400, "Invalid workflow configuration") @console_ns.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@edit_permission_required @edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
@ -210,11 +210,11 @@ class DraftWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
class AdvancedChatDraftWorkflowRunApi(Resource): class AdvancedChatDraftWorkflowRunApi(Resource):
@api.doc("run_advanced_chat_draft_workflow") @console_ns.doc("run_advanced_chat_draft_workflow")
@api.doc(description="Run draft workflow for advanced chat application") @console_ns.doc(description="Run draft workflow for advanced chat application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"AdvancedChatWorkflowRunRequest", "AdvancedChatWorkflowRunRequest",
{ {
"query": fields.String(required=True, description="User query"), "query": fields.String(required=True, description="User query"),
@ -224,9 +224,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
}, },
) )
) )
@api.response(200, "Workflow run started successfully") @console_ns.response(200, "Workflow run started successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -274,11 +274,11 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run")
class AdvancedChatDraftRunIterationNodeApi(Resource): class AdvancedChatDraftRunIterationNodeApi(Resource):
@api.doc("run_advanced_chat_draft_iteration_node") @console_ns.doc("run_advanced_chat_draft_iteration_node")
@api.doc(description="Run draft workflow iteration node for advanced chat") @console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"IterationNodeRunRequest", "IterationNodeRunRequest",
{ {
"task_id": fields.String(required=True, description="Task ID"), "task_id": fields.String(required=True, description="Task ID"),
@ -286,9 +286,9 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
}, },
) )
) )
@api.response(200, "Iteration node run started successfully") @console_ns.response(200, "Iteration node run started successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(404, "Node not found") @console_ns.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -321,11 +321,11 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class WorkflowDraftRunIterationNodeApi(Resource): class WorkflowDraftRunIterationNodeApi(Resource):
@api.doc("run_workflow_draft_iteration_node") @console_ns.doc("run_workflow_draft_iteration_node")
@api.doc(description="Run draft workflow iteration node") @console_ns.doc(description="Run draft workflow iteration node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"WorkflowIterationNodeRunRequest", "WorkflowIterationNodeRunRequest",
{ {
"task_id": fields.String(required=True, description="Task ID"), "task_id": fields.String(required=True, description="Task ID"),
@ -333,9 +333,9 @@ class WorkflowDraftRunIterationNodeApi(Resource):
}, },
) )
) )
@api.response(200, "Workflow iteration node run started successfully") @console_ns.response(200, "Workflow iteration node run started successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(404, "Node not found") @console_ns.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -368,11 +368,11 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run")
class AdvancedChatDraftRunLoopNodeApi(Resource): class AdvancedChatDraftRunLoopNodeApi(Resource):
@api.doc("run_advanced_chat_draft_loop_node") @console_ns.doc("run_advanced_chat_draft_loop_node")
@api.doc(description="Run draft workflow loop node for advanced chat") @console_ns.doc(description="Run draft workflow loop node for advanced chat")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"LoopNodeRunRequest", "LoopNodeRunRequest",
{ {
"task_id": fields.String(required=True, description="Task ID"), "task_id": fields.String(required=True, description="Task ID"),
@ -380,9 +380,9 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
}, },
) )
) )
@api.response(200, "Loop node run started successfully") @console_ns.response(200, "Loop node run started successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(404, "Node not found") @console_ns.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -415,11 +415,11 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class WorkflowDraftRunLoopNodeApi(Resource): class WorkflowDraftRunLoopNodeApi(Resource):
@api.doc("run_workflow_draft_loop_node") @console_ns.doc("run_workflow_draft_loop_node")
@api.doc(description="Run draft workflow loop node") @console_ns.doc(description="Run draft workflow loop node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"WorkflowLoopNodeRunRequest", "WorkflowLoopNodeRunRequest",
{ {
"task_id": fields.String(required=True, description="Task ID"), "task_id": fields.String(required=True, description="Task ID"),
@ -427,9 +427,9 @@ class WorkflowDraftRunLoopNodeApi(Resource):
}, },
) )
) )
@api.response(200, "Workflow loop node run started successfully") @console_ns.response(200, "Workflow loop node run started successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(404, "Node not found") @console_ns.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -462,11 +462,11 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource): class DraftWorkflowRunApi(Resource):
@api.doc("run_draft_workflow") @console_ns.doc("run_draft_workflow")
@api.doc(description="Run draft workflow") @console_ns.doc(description="Run draft workflow")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"DraftWorkflowRunRequest", "DraftWorkflowRunRequest",
{ {
"inputs": fields.Raw(required=True, description="Input variables"), "inputs": fields.Raw(required=True, description="Input variables"),
@ -474,8 +474,8 @@ class DraftWorkflowRunApi(Resource):
}, },
) )
) )
@api.response(200, "Draft workflow run started successfully") @console_ns.response(200, "Draft workflow run started successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -513,12 +513,12 @@ class DraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(Resource): class WorkflowTaskStopApi(Resource):
@api.doc("stop_workflow_task") @console_ns.doc("stop_workflow_task")
@api.doc(description="Stop running workflow task") @console_ns.doc(description="Stop running workflow task")
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID"}) @console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
@api.response(200, "Task stopped successfully") @console_ns.response(200, "Task stopped successfully")
@api.response(404, "Task not found") @console_ns.response(404, "Task not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -540,20 +540,20 @@ class WorkflowTaskStopApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
class DraftWorkflowNodeRunApi(Resource): class DraftWorkflowNodeRunApi(Resource):
@api.doc("run_draft_workflow_node") @console_ns.doc("run_draft_workflow_node")
@api.doc(description="Run draft workflow node") @console_ns.doc(description="Run draft workflow node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"DraftWorkflowNodeRunRequest", "DraftWorkflowNodeRunRequest",
{ {
"inputs": fields.Raw(description="Input variables"), "inputs": fields.Raw(description="Input variables"),
}, },
) )
) )
@api.response(200, "Node run started successfully", workflow_run_node_execution_fields) @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_fields)
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(404, "Node not found") @console_ns.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -607,11 +607,11 @@ parser_publish = (
@console_ns.route("/apps/<uuid:app_id>/workflows/publish") @console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource): class PublishedWorkflowApi(Resource):
@api.doc("get_published_workflow") @console_ns.doc("get_published_workflow")
@api.doc(description="Get published workflow for an application") @console_ns.doc(description="Get published workflow for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Published workflow retrieved successfully", workflow_fields) @console_ns.response(200, "Published workflow retrieved successfully", workflow_fields)
@api.response(404, "Published workflow not found") @console_ns.response(404, "Published workflow not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -629,7 +629,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None # return workflow, if not found, return None
return workflow return workflow
@api.expect(parser_publish) @console_ns.expect(parser_publish)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -678,10 +678,10 @@ class PublishedWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource): class DefaultBlockConfigsApi(Resource):
@api.doc("get_default_block_configs") @console_ns.doc("get_default_block_configs")
@api.doc(description="Get default block configurations for workflow") @console_ns.doc(description="Get default block configurations for workflow")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Default block configurations retrieved successfully") @console_ns.response(200, "Default block configurations retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -701,12 +701,12 @@ parser_block = reqparse.RequestParser().add_argument("q", type=str, location="ar
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource): class DefaultBlockConfigApi(Resource):
@api.doc("get_default_block_config") @console_ns.doc("get_default_block_config")
@api.doc(description="Get default block configuration by type") @console_ns.doc(description="Get default block configuration by type")
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"}) @console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@api.response(200, "Default block configuration retrieved successfully") @console_ns.response(200, "Default block configuration retrieved successfully")
@api.response(404, "Block type not found") @console_ns.response(404, "Block type not found")
@api.expect(parser_block) @console_ns.expect(parser_block)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -743,13 +743,13 @@ parser_convert = (
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow") @console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource): class ConvertToWorkflowApi(Resource):
@api.expect(parser_convert) @console_ns.expect(parser_convert)
@api.doc("convert_to_workflow") @console_ns.doc("convert_to_workflow")
@api.doc(description="Convert application to workflow mode") @console_ns.doc(description="Convert application to workflow mode")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Application converted to workflow successfully") @console_ns.response(200, "Application converted to workflow successfully")
@api.response(400, "Application cannot be converted") @console_ns.response(400, "Application cannot be converted")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -789,11 +789,11 @@ parser_workflows = (
@console_ns.route("/apps/<uuid:app_id>/workflows") @console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource): class PublishedAllWorkflowApi(Resource):
@api.expect(parser_workflows) @console_ns.expect(parser_workflows)
@api.doc("get_all_published_workflows") @console_ns.doc("get_all_published_workflows")
@api.doc(description="Get all published workflows for an application") @console_ns.doc(description="Get all published workflows for an application")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields) @console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -838,11 +838,11 @@ class PublishedAllWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>") @console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource): class WorkflowByIdApi(Resource):
@api.doc("update_workflow_by_id") @console_ns.doc("update_workflow_by_id")
@api.doc(description="Update workflow by ID") @console_ns.doc(description="Update workflow by ID")
@api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateWorkflowRequest", "UpdateWorkflowRequest",
{ {
"environment_variables": fields.List(fields.Raw, description="Environment variables"), "environment_variables": fields.List(fields.Raw, description="Environment variables"),
@ -850,9 +850,9 @@ class WorkflowByIdApi(Resource):
}, },
) )
) )
@api.response(200, "Workflow updated successfully", workflow_fields) @console_ns.response(200, "Workflow updated successfully", workflow_fields)
@api.response(404, "Workflow not found") @console_ns.response(404, "Workflow not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -938,12 +938,12 @@ class WorkflowByIdApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run")
class DraftWorkflowNodeLastRunApi(Resource): class DraftWorkflowNodeLastRunApi(Resource):
@api.doc("get_draft_workflow_node_last_run") @console_ns.doc("get_draft_workflow_node_last_run")
@api.doc(description="Get last run result for draft workflow node") @console_ns.doc(description="Get last run result for draft workflow node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields) @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields)
@api.response(404, "Node last run not found") @console_ns.response(404, "Node last run not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -971,20 +971,20 @@ class DraftWorkflowTriggerRunApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
""" """
@api.doc("poll_draft_workflow_trigger_run") @console_ns.doc("poll_draft_workflow_trigger_run")
@api.doc(description="Poll for trigger events and execute full workflow when event arrives") @console_ns.doc(description="Poll for trigger events and execute full workflow when event arrives")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"DraftWorkflowTriggerRunRequest", "DraftWorkflowTriggerRunRequest",
{ {
"node_id": fields.String(required=True, description="Node ID"), "node_id": fields.String(required=True, description="Node ID"),
}, },
) )
) )
@api.response(200, "Trigger event received and workflow executed successfully") @console_ns.response(200, "Trigger event received and workflow executed successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(500, "Internal server error") @console_ns.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -995,8 +995,9 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives Poll for trigger events and execute full workflow when event arrives
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False) "node_id", type=str, required=True, location="json", nullable=False
)
args = parser.parse_args() args = parser.parse_args()
node_id = args["node_id"] node_id = args["node_id"]
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -1044,12 +1045,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
""" """
@api.doc("poll_draft_workflow_trigger_node") @console_ns.doc("poll_draft_workflow_trigger_node")
@api.doc(description="Poll for trigger events and execute single node when event arrives") @console_ns.doc(description="Poll for trigger events and execute single node when event arrives")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Trigger event received and node executed successfully") @console_ns.response(200, "Trigger event received and node executed successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(500, "Internal server error") @console_ns.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1123,20 +1124,20 @@ class DraftWorkflowTriggerRunAllApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
""" """
@api.doc("draft_workflow_trigger_run_all") @console_ns.doc("draft_workflow_trigger_run_all")
@api.doc(description="Full workflow debug when the start node is a trigger") @console_ns.doc(description="Full workflow debug when the start node is a trigger")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"DraftWorkflowTriggerRunAllRequest", "DraftWorkflowTriggerRunAllRequest",
{ {
"node_ids": fields.List(fields.String, required=True, description="Node IDs"), "node_ids": fields.List(fields.String, required=True, description="Node IDs"),
}, },
) )
) )
@api.response(200, "Workflow executed successfully") @console_ns.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@api.response(500, "Internal server error") @console_ns.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1148,8 +1149,9 @@ class DraftWorkflowTriggerRunAllApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False) "node_ids", type=list, required=True, location="json", nullable=False
)
args = parser.parse_args() args = parser.parse_args()
node_ids = args["node_ids"] node_ids = args["node_ids"]
workflow_service = WorkflowService() workflow_service = WorkflowService()

View File

@ -3,7 +3,7 @@ from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
@ -17,10 +17,10 @@ from services.workflow_app_service import WorkflowAppService
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs") @console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
class WorkflowAppLogApi(Resource): class WorkflowAppLogApi(Resource):
@api.doc("get_workflow_app_logs") @console_ns.doc("get_workflow_app_logs")
@api.doc(description="Get workflow application execution logs") @console_ns.doc(description="Get workflow application execution logs")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc( @console_ns.doc(
params={ params={
"keyword": "Search keyword for filtering logs", "keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)", "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
@ -33,7 +33,7 @@ class WorkflowAppLogApi(Resource):
"limit": "Number of items per page (1-100)", "limit": "Number of items per page (1-100)",
} }
) )
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields) @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,17 +1,18 @@
import logging import logging
from typing import NoReturn from collections.abc import Callable
from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
) )
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup from core.variables.segment_group import SegmentGroup
@ -21,8 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
from extensions.ext_database import db from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import login_required
from models import Account, App, AppMode from models import App, AppMode
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -140,8 +141,11 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
} }
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f):
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs. """Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied: It ensures the following conditions are satisfied:
@ -155,11 +159,10 @@ def _api_prerequisite(f):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs): @wraps(f)
assert isinstance(current_user, Account) def wrapper(*args: P.args, **kwargs: P.kwargs):
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@ -167,11 +170,14 @@ def _api_prerequisite(f):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource): class WorkflowVariableCollectionApi(Resource):
@api.doc("get_workflow_variables") @console_ns.expect(_create_pagination_parser())
@api.doc(description="Get draft workflow variables") @console_ns.doc("get_workflow_variables")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(description="Get draft workflow variables")
@api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
@console_ns.response(
200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS
)
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
@ -200,9 +206,9 @@ class WorkflowVariableCollectionApi(Resource):
return workflow_vars return workflow_vars
@api.doc("delete_workflow_variables") @console_ns.doc("delete_workflow_variables")
@api.doc(description="Delete all draft workflow variables") @console_ns.doc(description="Delete all draft workflow variables")
@api.response(204, "Workflow variables deleted successfully") @console_ns.response(204, "Workflow variables deleted successfully")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App): def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -233,10 +239,10 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
class NodeVariableCollectionApi(Resource): class NodeVariableCollectionApi(Resource):
@api.doc("get_node_variables") @console_ns.doc("get_node_variables")
@api.doc(description="Get variables for a specific node") @console_ns.doc(description="Get variables for a specific node")
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @console_ns.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str): def get(self, app_model: App, node_id: str):
@ -249,9 +255,9 @@ class NodeVariableCollectionApi(Resource):
return node_vars return node_vars
@api.doc("delete_node_variables") @console_ns.doc("delete_node_variables")
@api.doc(description="Delete all variables for a specific node") @console_ns.doc(description="Delete all variables for a specific node")
@api.response(204, "Node variables deleted successfully") @console_ns.response(204, "Node variables deleted successfully")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App, node_id: str): def delete(self, app_model: App, node_id: str):
validate_node_id(node_id) validate_node_id(node_id)
@ -266,11 +272,11 @@ class VariableApi(Resource):
_PATCH_NAME_FIELD = "name" _PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value" _PATCH_VALUE_FIELD = "value"
@api.doc("get_variable") @console_ns.doc("get_variable")
@api.doc(description="Get a specific workflow variable") @console_ns.doc(description="Get a specific workflow variable")
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) @console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) @console_ns.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str): def get(self, app_model: App, variable_id: str):
@ -284,10 +290,10 @@ class VariableApi(Resource):
raise NotFoundError(description=f"variable not found, id={variable_id}") raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable return variable
@api.doc("update_variable") @console_ns.doc("update_variable")
@api.doc(description="Update a workflow variable") @console_ns.doc(description="Update a workflow variable")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateVariableRequest", "UpdateVariableRequest",
{ {
"name": fields.String(description="Variable name"), "name": fields.String(description="Variable name"),
@ -295,8 +301,8 @@ class VariableApi(Resource):
}, },
) )
) )
@api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) @console_ns.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str): def patch(self, app_model: App, variable_id: str):
@ -360,10 +366,10 @@ class VariableApi(Resource):
db.session.commit() db.session.commit()
return variable return variable
@api.doc("delete_variable") @console_ns.doc("delete_variable")
@api.doc(description="Delete a workflow variable") @console_ns.doc(description="Delete a workflow variable")
@api.response(204, "Variable deleted successfully") @console_ns.response(204, "Variable deleted successfully")
@api.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App, variable_id: str): def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -381,12 +387,12 @@ class VariableApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class VariableResetApi(Resource): class VariableResetApi(Resource):
@api.doc("reset_variable") @console_ns.doc("reset_variable")
@api.doc(description="Reset a workflow variable to its default value") @console_ns.doc(description="Reset a workflow variable to its default value")
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) @console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS) @console_ns.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@api.response(204, "Variable reset (no content)") @console_ns.response(204, "Variable reset (no content)")
@api.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
def put(self, app_model: App, variable_id: str): def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -429,11 +435,11 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
class ConversationVariableCollectionApi(Resource): class ConversationVariableCollectionApi(Resource):
@api.doc("get_conversation_variables") @console_ns.doc("get_conversation_variables")
@api.doc(description="Get conversation variables for workflow") @console_ns.doc(description="Get conversation variables for workflow")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @console_ns.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@api.response(404, "Draft workflow not found") @console_ns.response(404, "Draft workflow not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
@ -451,10 +457,10 @@ class ConversationVariableCollectionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource): class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables") @console_ns.doc("get_system_variables")
@api.doc(description="Get system variables for workflow") @console_ns.doc(description="Get system variables for workflow")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @console_ns.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
@ -463,11 +469,11 @@ class SystemVariableCollectionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
class EnvironmentVariableCollectionApi(Resource): class EnvironmentVariableCollectionApi(Resource):
@api.doc("get_environment_variables") @console_ns.doc("get_environment_variables")
@api.doc(description="Get environment variables for workflow") @console_ns.doc(description="Get environment variables for workflow")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.response(200, "Environment variables retrieved successfully") @console_ns.response(200, "Environment variables retrieved successfully")
@api.response(404, "Draft workflow not found") @console_ns.response(404, "Draft workflow not found")
@_api_prerequisite @_api_prerequisite
def get(self, app_model: App): def get(self, app_model: App):
""" """

View File

@ -3,7 +3,7 @@ from typing import cast
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import ( from fields.workflow_run_fields import (
@ -90,13 +90,17 @@ def _parse_workflow_run_count_args():
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource): class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc("get_advanced_chat_workflow_runs") @console_ns.doc("get_advanced_chat_workflow_runs")
@api.doc(description="Get advanced chat workflow run list") @console_ns.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) @console_ns.doc(
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) )
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -125,11 +129,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource): class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count") @console_ns.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics") @console_ns.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) @console_ns.doc(
@api.doc( params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={ params={
"time_range": ( "time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@ -137,8 +143,10 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
) )
} }
) )
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) @console_ns.doc(
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -170,13 +178,17 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/workflow-runs")
class WorkflowRunListApi(Resource): class WorkflowRunListApi(Resource):
@api.doc("get_workflow_runs") @console_ns.doc("get_workflow_runs")
@api.doc(description="Get workflow run list") @console_ns.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) @console_ns.doc(
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) )
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -205,11 +217,13 @@ class WorkflowRunListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource): class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count") @console_ns.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics") @console_ns.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}) @console_ns.doc(
@api.doc( params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
)
@console_ns.doc(
params={ params={
"time_range": ( "time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@ -217,8 +231,10 @@ class WorkflowRunCountApi(Resource):
) )
} }
) )
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}) @console_ns.doc(
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields) params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -250,11 +266,11 @@ class WorkflowRunCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
class WorkflowRunDetailApi(Resource): class WorkflowRunDetailApi(Resource):
@api.doc("get_workflow_run_detail") @console_ns.doc("get_workflow_run_detail")
@api.doc(description="Get workflow run detail") @console_ns.doc(description="Get workflow run detail")
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields) @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
@api.response(404, "Workflow run not found") @console_ns.response(404, "Workflow run not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -274,11 +290,11 @@ class WorkflowRunDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
class WorkflowRunNodeExecutionListApi(Resource): class WorkflowRunNodeExecutionListApi(Resource):
@api.doc("get_workflow_run_node_executions") @console_ns.doc("get_workflow_run_node_executions")
@api.doc(description="Get workflow run node execution list") @console_ns.doc(description="Get workflow run node execution list")
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields) @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
@api.response(404, "Workflow run not found") @console_ns.response(404, "Workflow run not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -2,7 +2,7 @@ from flask import abort, jsonify
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
@ -21,11 +21,13 @@ class WorkflowDailyRunsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_runs_statistic") @console_ns.doc("get_workflow_daily_runs_statistic")
@api.doc(description="Get workflow daily runs statistics") @console_ns.doc(description="Get workflow daily runs statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) @console_ns.doc(
@api.response(200, "Daily runs statistics retrieved successfully") params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -66,11 +68,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_terminals_statistic") @console_ns.doc("get_workflow_daily_terminals_statistic")
@api.doc(description="Get workflow daily terminals statistics") @console_ns.doc(description="Get workflow daily terminals statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) @console_ns.doc(
@api.response(200, "Daily terminals statistics retrieved successfully") params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -111,11 +115,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_daily_token_cost_statistic") @console_ns.doc("get_workflow_daily_token_cost_statistic")
@api.doc(description="Get workflow daily token cost statistics") @console_ns.doc(description="Get workflow daily token cost statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) @console_ns.doc(
@api.response(200, "Daily token cost statistics retrieved successfully") params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -156,11 +162,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@api.doc("get_workflow_average_app_interaction_statistic") @console_ns.doc("get_workflow_average_app_interaction_statistic")
@api.doc(description="Get workflow average app interaction statistics") @console_ns.doc(description="Get workflow average app interaction statistics")
@api.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}) @console_ns.doc(
@api.response(200, "Average app interaction statistics retrieved successfully") params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,12 +3,12 @@ import logging
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -29,8 +29,7 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields) @marshal_with(webhook_trigger_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Get webhook trigger for a node""" """Get webhook trigger for a node"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args() args = parser.parse_args()
node_id = str(args["node_id"]) node_id = str(args["node_id"])
@ -95,19 +94,19 @@ class AppTriggerEnableApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields) @marshal_with(trigger_fields)
def post(self, app_model: App): def post(self, app_model: App):
"""Update app trigger (enable/disable)""" """Update app trigger (enable/disable)"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json") .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"] trigger_id = args["trigger_id"]
@ -140,6 +139,6 @@ class AppTriggerEnableApi(Resource):
return trigger return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook") console_ns.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers") console_ns.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable") console_ns.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -20,13 +20,13 @@ active_check_parser = (
@console_ns.route("/activate/check") @console_ns.route("/activate/check")
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@api.doc("check_activation_token") @console_ns.doc("check_activation_token")
@api.doc(description="Check if activation token is valid") @console_ns.doc(description="Check if activation token is valid")
@api.expect(active_check_parser) @console_ns.expect(active_check_parser)
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model( console_ns.model(
"ActivationCheckResponse", "ActivationCheckResponse",
{ {
"is_valid": fields.Boolean(description="Whether token is valid"), "is_valid": fields.Boolean(description="Whether token is valid"),
@ -69,13 +69,13 @@ active_parser = (
@console_ns.route("/activate") @console_ns.route("/activate")
class ActivateApi(Resource): class ActivateApi(Resource):
@api.doc("activate_account") @console_ns.doc("activate_account")
@api.doc(description="Activate account with invitation token") @console_ns.doc(description="Activate account with invitation token")
@api.expect(active_parser) @console_ns.expect(active_parser)
@api.response( @console_ns.response(
200, 200,
"Account activated successfully", "Account activated successfully",
api.model( console_ns.model(
"ActivationResponse", "ActivationResponse",
{ {
"result": fields.String(description="Operation result"), "result": fields.String(description="Operation result"),
@ -83,7 +83,7 @@ class ActivateApi(Resource):
}, },
), ),
) )
@api.response(400, "Already activated or invalid token") @console_ns.response(400, "Already activated or invalid token")
def post(self): def post(self):
args = active_parser.parse_args() args = active_parser.parse_args()

View File

@ -1,8 +1,8 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@ -39,12 +39,10 @@ class ApiKeyAuthDataSourceBinding(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
@ -65,12 +63,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required
def delete(self, binding_id): def delete(self, binding_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)

View File

@ -3,11 +3,11 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from libs.login import current_account_with_tenant, login_required from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -29,24 +29,22 @@ def get_oauth_providers():
@console_ns.route("/oauth/data-source/<string:provider>") @console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource): class OAuthDataSource(Resource):
@api.doc("oauth_data_source") @console_ns.doc("oauth_data_source")
@api.doc(description="Get OAuth authorization URL for data source provider") @console_ns.doc(description="Get OAuth authorization URL for data source provider")
@api.doc(params={"provider": "Data source provider name (notion)"}) @console_ns.doc(params={"provider": "Data source provider name (notion)"})
@api.response( @console_ns.response(
200, 200,
"Authorization URL or internal setup success", "Authorization URL or internal setup success",
api.model( console_ns.model(
"OAuthDataSourceResponse", "OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
), ),
) )
@api.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@is_admin_or_owner_required
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
@ -65,17 +63,17 @@ class OAuthDataSource(Resource):
@console_ns.route("/oauth/data-source/callback/<string:provider>") @console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource): class OAuthDataSourceCallback(Resource):
@api.doc("oauth_data_source_callback") @console_ns.doc("oauth_data_source_callback")
@api.doc(description="Handle OAuth callback from data source provider") @console_ns.doc(description="Handle OAuth callback from data source provider")
@api.doc( @console_ns.doc(
params={ params={
"provider": "Data source provider name (notion)", "provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider", "code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider", "error": "Error message from OAuth provider",
} }
) )
@api.response(302, "Redirect to console with result") @console_ns.response(302, "Redirect to console with result")
@api.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -96,17 +94,17 @@ class OAuthDataSourceCallback(Resource):
@console_ns.route("/oauth/data-source/binding/<string:provider>") @console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource): class OAuthDataSourceBinding(Resource):
@api.doc("oauth_data_source_binding") @console_ns.doc("oauth_data_source_binding")
@api.doc(description="Bind OAuth data source with authorization code") @console_ns.doc(description="Bind OAuth data source with authorization code")
@api.doc( @console_ns.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
) )
@api.response( @console_ns.response(
200, 200,
"Data source binding success", "Data source binding success",
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
) )
@api.response(400, "Invalid provider or code") @console_ns.response(400, "Invalid provider or code")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -130,15 +128,15 @@ class OAuthDataSourceBinding(Resource):
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") @console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource): class OAuthDataSourceSync(Resource):
@api.doc("oauth_data_source_sync") @console_ns.doc("oauth_data_source_sync")
@api.doc(description="Sync data from OAuth data source") @console_ns.doc(description="Sync data from OAuth data source")
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) @console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@api.response( @console_ns.response(
200, 200,
"Data source sync success", "Data source sync success",
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
) )
@api.response(400, "Invalid provider or sync failed") @console_ns.response(400, "Invalid provider or sync failed")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, fields, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError, EmailCodeError,
EmailPasswordResetLimitError, EmailPasswordResetLimitError,
@ -27,10 +27,10 @@ from services.feature_service import FeatureService
@console_ns.route("/forgot-password") @console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@api.doc("send_forgot_password_email") @console_ns.doc("send_forgot_password_email")
@api.doc(description="Send password reset email") @console_ns.doc(description="Send password reset email")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ForgotPasswordEmailRequest", "ForgotPasswordEmailRequest",
{ {
"email": fields.String(required=True, description="Email address"), "email": fields.String(required=True, description="Email address"),
@ -38,10 +38,10 @@ class ForgotPasswordSendEmailApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Email sent successfully", "Email sent successfully",
api.model( console_ns.model(
"ForgotPasswordEmailResponse", "ForgotPasswordEmailResponse",
{ {
"result": fields.String(description="Operation result"), "result": fields.String(description="Operation result"),
@ -50,7 +50,7 @@ class ForgotPasswordSendEmailApi(Resource):
}, },
), ),
) )
@api.response(400, "Invalid email or rate limit exceeded") @console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -85,10 +85,10 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.route("/forgot-password/validity") @console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@api.doc("check_forgot_password_code") @console_ns.doc("check_forgot_password_code")
@api.doc(description="Verify password reset code") @console_ns.doc(description="Verify password reset code")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ForgotPasswordCheckRequest", "ForgotPasswordCheckRequest",
{ {
"email": fields.String(required=True, description="Email address"), "email": fields.String(required=True, description="Email address"),
@ -97,10 +97,10 @@ class ForgotPasswordCheckApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Code verified successfully", "Code verified successfully",
api.model( console_ns.model(
"ForgotPasswordCheckResponse", "ForgotPasswordCheckResponse",
{ {
"is_valid": fields.Boolean(description="Whether code is valid"), "is_valid": fields.Boolean(description="Whether code is valid"),
@ -109,7 +109,7 @@ class ForgotPasswordCheckApi(Resource):
}, },
), ),
) )
@api.response(400, "Invalid code or token") @console_ns.response(400, "Invalid code or token")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -152,10 +152,10 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.route("/forgot-password/resets") @console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@api.doc("reset_password") @console_ns.doc("reset_password")
@api.doc(description="Reset password with verification token") @console_ns.doc(description="Reset password with verification token")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ForgotPasswordResetRequest", "ForgotPasswordResetRequest",
{ {
"token": fields.String(required=True, description="Verification token"), "token": fields.String(required=True, description="Verification token"),
@ -164,12 +164,12 @@ class ForgotPasswordResetApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Password reset successfully", "Password reset successfully",
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
) )
@api.response(400, "Invalid token or password mismatch") @console_ns.response(400, "Invalid token or password mismatch")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):

View File

@ -26,7 +26,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .. import api, console_ns from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,11 +56,13 @@ def get_oauth_providers():
@console_ns.route("/oauth/login/<provider>") @console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource): class OAuthLogin(Resource):
@api.doc("oauth_login") @console_ns.doc("oauth_login")
@api.doc(description="Initiate OAuth login process") @console_ns.doc(description="Initiate OAuth login process")
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}) @console_ns.doc(
@api.response(302, "Redirect to OAuth authorization URL") params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
@api.response(400, "Invalid provider") )
@console_ns.response(302, "Redirect to OAuth authorization URL")
@console_ns.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
invite_token = request.args.get("invite_token") or None invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
@ -75,17 +77,17 @@ class OAuthLogin(Resource):
@console_ns.route("/oauth/authorize/<provider>") @console_ns.route("/oauth/authorize/<provider>")
class OAuthCallback(Resource): class OAuthCallback(Resource):
@api.doc("oauth_callback") @console_ns.doc("oauth_callback")
@api.doc(description="Handle OAuth callback and complete login process") @console_ns.doc(description="Handle OAuth callback and complete login process")
@api.doc( @console_ns.doc(
params={ params={
"provider": "OAuth provider name (github/google)", "provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider", "code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)", "state": "Optional state parameter (used for invite token)",
} }
) )
@api.response(302, "Redirect to console with access token") @console_ns.response(302, "Redirect to console with access token")
@api.response(400, "OAuth process failed") @console_ns.response(400, "OAuth process failed")
def get(self, provider: str): def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():

View File

@ -1,4 +1,7 @@
from flask_restx import Resource, reqparse import base64
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
@ -41,3 +44,37 @@ class Invoices(Resource):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id) return BillingService.get_invoices(current_user.email, current_tenant_id)
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
class PartnerTenants(Resource):
@console_ns.doc("sync_partner_tenants_bindings")
@console_ns.doc(description="Sync partner tenants bindings")
@console_ns.doc(params={"partner_key": "Partner key"})
@console_ns.expect(
console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@console_ns.response(200, "Tenants synced to partner successfully")
@console_ns.response(400, "Invalid partner information")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
click_id = args["click_id"]
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")
if not click_id or not decoded_partner_key or not current_user.id:
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
@ -15,6 +15,7 @@ from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -118,9 +119,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
@console_ns.route("/datasets") @console_ns.route("/datasets")
class DatasetListApi(Resource): class DatasetListApi(Resource):
@api.doc("get_datasets") @console_ns.doc("get_datasets")
@api.doc(description="Get list of datasets") @console_ns.doc(description="Get list of datasets")
@api.doc( @console_ns.doc(
params={ params={
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)", "limit": "Number of items per page (default: 20)",
@ -130,7 +131,7 @@ class DatasetListApi(Resource):
"include_all": "Include all datasets (default: false)", "include_all": "Include all datasets (default: false)",
} }
) )
@api.response(200, "Datasets retrieved successfully") @console_ns.response(200, "Datasets retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -183,10 +184,10 @@ class DatasetListApi(Resource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@api.doc("create_dataset") @console_ns.doc("create_dataset")
@api.doc(description="Create a new dataset") @console_ns.doc(description="Create a new dataset")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CreateDatasetRequest", "CreateDatasetRequest",
{ {
"name": fields.String(required=True, description="Dataset name (1-40 characters)"), "name": fields.String(required=True, description="Dataset name (1-40 characters)"),
@ -199,8 +200,8 @@ class DatasetListApi(Resource):
}, },
) )
) )
@api.response(201, "Dataset created successfully") @console_ns.response(201, "Dataset created successfully")
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -278,12 +279,12 @@ class DatasetListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>") @console_ns.route("/datasets/<uuid:dataset_id>")
class DatasetApi(Resource): class DatasetApi(Resource):
@api.doc("get_dataset") @console_ns.doc("get_dataset")
@api.doc(description="Get dataset details") @console_ns.doc(description="Get dataset details")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Dataset retrieved successfully", dataset_detail_fields) @console_ns.response(200, "Dataset retrieved successfully", dataset_detail_fields)
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -327,10 +328,10 @@ class DatasetApi(Resource):
return data, 200 return data, 200
@api.doc("update_dataset") @console_ns.doc("update_dataset")
@api.doc(description="Update dataset details") @console_ns.doc(description="Update dataset details")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateDatasetRequest", "UpdateDatasetRequest",
{ {
"name": fields.String(description="Dataset name"), "name": fields.String(description="Dataset name"),
@ -341,9 +342,9 @@ class DatasetApi(Resource):
}, },
) )
) )
@api.response(200, "Dataset updated successfully", dataset_detail_fields) @console_ns.response(200, "Dataset updated successfully", dataset_detail_fields)
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -487,10 +488,10 @@ class DatasetApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/use-check") @console_ns.route("/datasets/<uuid:dataset_id>/use-check")
class DatasetUseCheckApi(Resource): class DatasetUseCheckApi(Resource):
@api.doc("check_dataset_use") @console_ns.doc("check_dataset_use")
@api.doc(description="Check if dataset is in use") @console_ns.doc(description="Check if dataset is in use")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Dataset use status retrieved successfully") @console_ns.response(200, "Dataset use status retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -503,10 +504,10 @@ class DatasetUseCheckApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/queries") @console_ns.route("/datasets/<uuid:dataset_id>/queries")
class DatasetQueryApi(Resource): class DatasetQueryApi(Resource):
@api.doc("get_dataset_queries") @console_ns.doc("get_dataset_queries")
@api.doc(description="Get dataset query history") @console_ns.doc(description="Get dataset query history")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Query history retrieved successfully", dataset_query_detail_fields) @console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -539,9 +540,9 @@ class DatasetQueryApi(Resource):
@console_ns.route("/datasets/indexing-estimate") @console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource): class DatasetIndexingEstimateApi(Resource):
@api.doc("estimate_dataset_indexing") @console_ns.doc("estimate_dataset_indexing")
@api.doc(description="Estimate dataset indexing cost") @console_ns.doc(description="Estimate dataset indexing cost")
@api.response(200, "Indexing estimate calculated successfully") @console_ns.response(200, "Indexing estimate calculated successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -649,10 +650,10 @@ class DatasetIndexingEstimateApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/related-apps") @console_ns.route("/datasets/<uuid:dataset_id>/related-apps")
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
@api.doc("get_dataset_related_apps") @console_ns.doc("get_dataset_related_apps")
@api.doc(description="Get applications related to dataset") @console_ns.doc(description="Get applications related to dataset")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Related apps retrieved successfully", related_app_list) @console_ns.response(200, "Related apps retrieved successfully", related_app_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -682,10 +683,10 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status") @console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
class DatasetIndexingStatusApi(Resource): class DatasetIndexingStatusApi(Resource):
@api.doc("get_dataset_indexing_status") @console_ns.doc("get_dataset_indexing_status")
@api.doc(description="Get dataset indexing status") @console_ns.doc(description="Get dataset indexing status")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Indexing status retrieved successfully") @console_ns.response(200, "Indexing status retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -737,9 +738,9 @@ class DatasetApiKeyApi(Resource):
token_prefix = "dataset-" token_prefix = "dataset-"
resource_type = "dataset" resource_type = "dataset"
@api.doc("get_dataset_api_keys") @console_ns.doc("get_dataset_api_keys")
@api.doc(description="Get dataset API keys") @console_ns.doc(description="Get dataset API keys")
@api.response(200, "API keys retrieved successfully", api_key_list) @console_ns.response(200, "API keys retrieved successfully", api_key_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -753,13 +754,11 @@ class DatasetApiKeyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner _, current_tenant_id = current_account_with_tenant()
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
@ -768,7 +767,7 @@ class DatasetApiKeyApi(Resource):
) )
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
api.abort( console_ns.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code="max_keys_exceeded",
@ -788,21 +787,17 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource): class DatasetApiDeleteApi(Resource):
resource_type = "dataset" resource_type = "dataset"
@api.doc("delete_dataset_api_key") @console_ns.doc("delete_dataset_api_key")
@api.doc(description="Delete dataset API key") @console_ns.doc(description="Delete dataset API key")
@api.doc(params={"api_key_id": "API key ID"}) @console_ns.doc(params={"api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully") @console_ns.response(204, "API key deleted successfully")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, api_key_id): def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .where(
@ -814,7 +809,7 @@ class DatasetApiDeleteApi(Resource):
) )
if key is None: if key is None:
api.abort(404, message="API key not found") console_ns.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
@ -837,9 +832,9 @@ class DatasetEnableApiApi(Resource):
@console_ns.route("/datasets/api-base-info") @console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource): class DatasetApiBaseUrlApi(Resource):
@api.doc("get_dataset_api_base_info") @console_ns.doc("get_dataset_api_base_info")
@api.doc(description="Get dataset API base information") @console_ns.doc(description="Get dataset API base information")
@api.response(200, "API base info retrieved successfully") @console_ns.response(200, "API base info retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -849,9 +844,9 @@ class DatasetApiBaseUrlApi(Resource):
@console_ns.route("/datasets/retrieval-setting") @console_ns.route("/datasets/retrieval-setting")
class DatasetRetrievalSettingApi(Resource): class DatasetRetrievalSettingApi(Resource):
@api.doc("get_dataset_retrieval_setting") @console_ns.doc("get_dataset_retrieval_setting")
@api.doc(description="Get dataset retrieval settings") @console_ns.doc(description="Get dataset retrieval settings")
@api.response(200, "Retrieval settings retrieved successfully") @console_ns.response(200, "Retrieval settings retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -862,10 +857,10 @@ class DatasetRetrievalSettingApi(Resource):
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>") @console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
class DatasetRetrievalSettingMockApi(Resource): class DatasetRetrievalSettingMockApi(Resource):
@api.doc("get_dataset_retrieval_setting_mock") @console_ns.doc("get_dataset_retrieval_setting_mock")
@api.doc(description="Get mock dataset retrieval settings by vector type") @console_ns.doc(description="Get mock dataset retrieval settings by vector type")
@api.doc(params={"vector_type": "Vector store type"}) @console_ns.doc(params={"vector_type": "Vector store type"})
@api.response(200, "Mock retrieval settings retrieved successfully") @console_ns.response(200, "Mock retrieval settings retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -875,11 +870,11 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs") @console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
class DatasetErrorDocs(Resource): class DatasetErrorDocs(Resource):
@api.doc("get_dataset_error_docs") @console_ns.doc("get_dataset_error_docs")
@api.doc(description="Get dataset error documents") @console_ns.doc(description="Get dataset error documents")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Error documents retrieved successfully") @console_ns.response(200, "Error documents retrieved successfully")
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -895,12 +890,12 @@ class DatasetErrorDocs(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users") @console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
class DatasetPermissionUserListApi(Resource): class DatasetPermissionUserListApi(Resource):
@api.doc("get_dataset_permission_users") @console_ns.doc("get_dataset_permission_users")
@api.doc(description="Get dataset permission user list") @console_ns.doc(description="Get dataset permission user list")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Permission users retrieved successfully") @console_ns.response(200, "Permission users retrieved successfully")
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -924,11 +919,11 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs") @console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
class DatasetAutoDisableLogApi(Resource): class DatasetAutoDisableLogApi(Resource):
@api.doc("get_dataset_auto_disable_logs") @console_ns.doc("get_dataset_auto_disable_logs")
@api.doc(description="Get dataset auto disable logs") @console_ns.doc(description="Get dataset auto disable logs")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.response(200, "Auto disable logs retrieved successfully") @console_ns.response(200, "Auto disable logs retrieved successfully")
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -11,7 +11,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderNotInitializeError,
@ -104,10 +104,10 @@ class DocumentResource(Resource):
@console_ns.route("/datasets/process-rule") @console_ns.route("/datasets/process-rule")
class GetProcessRuleApi(Resource): class GetProcessRuleApi(Resource):
@api.doc("get_process_rule") @console_ns.doc("get_process_rule")
@api.doc(description="Get dataset document processing rules") @console_ns.doc(description="Get dataset document processing rules")
@api.doc(params={"document_id": "Document ID (optional)"}) @console_ns.doc(params={"document_id": "Document ID (optional)"})
@api.response(200, "Process rules retrieved successfully") @console_ns.response(200, "Process rules retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -152,9 +152,9 @@ class GetProcessRuleApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents") @console_ns.route("/datasets/<uuid:dataset_id>/documents")
class DatasetDocumentListApi(Resource): class DatasetDocumentListApi(Resource):
@api.doc("get_dataset_documents") @console_ns.doc("get_dataset_documents")
@api.doc(description="Get documents in a dataset") @console_ns.doc(description="Get documents in a dataset")
@api.doc( @console_ns.doc(
params={ params={
"dataset_id": "Dataset ID", "dataset_id": "Dataset ID",
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
@ -162,9 +162,10 @@ class DatasetDocumentListApi(Resource):
"keyword": "Search keyword", "keyword": "Search keyword",
"sort": "Sort order (default: -created_at)", "sort": "Sort order (default: -created_at)",
"fetch": "Fetch full details (default: false)", "fetch": "Fetch full details (default: false)",
"status": "Filter documents by display status",
} }
) )
@api.response(200, "Documents retrieved successfully") @console_ns.response(200, "Documents retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -175,6 +176,7 @@ class DatasetDocumentListApi(Resource):
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str) sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False. # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try: try:
fetch_val = request.args.get("fetch", default="false") fetch_val = request.args.get("fetch", default="false")
@ -203,6 +205,9 @@ class DatasetDocumentListApi(Resource):
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.where(Document.name.like(search))
@ -352,10 +357,10 @@ class DatasetDocumentListApi(Resource):
@console_ns.route("/datasets/init") @console_ns.route("/datasets/init")
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@api.doc("init_dataset") @console_ns.doc("init_dataset")
@api.doc(description="Initialize dataset with documents") @console_ns.doc(description="Initialize dataset with documents")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"DatasetInitRequest", "DatasetInitRequest",
{ {
"upload_file_id": fields.String(required=True, description="Upload file ID"), "upload_file_id": fields.String(required=True, description="Upload file ID"),
@ -365,8 +370,8 @@ class DatasetInitApi(Resource):
}, },
) )
) )
@api.response(201, "Dataset initialized successfully", dataset_and_document_fields) @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_fields)
@api.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -441,12 +446,12 @@ class DatasetInitApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
class DocumentIndexingEstimateApi(DocumentResource): class DocumentIndexingEstimateApi(DocumentResource):
@api.doc("estimate_document_indexing") @console_ns.doc("estimate_document_indexing")
@api.doc(description="Estimate document indexing cost") @console_ns.doc(description="Estimate document indexing cost")
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@api.response(200, "Indexing estimate calculated successfully") @console_ns.response(200, "Indexing estimate calculated successfully")
@api.response(404, "Document not found") @console_ns.response(404, "Document not found")
@api.response(400, "Document already finished") @console_ns.response(400, "Document already finished")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -656,11 +661,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
class DocumentIndexingStatusApi(DocumentResource): class DocumentIndexingStatusApi(DocumentResource):
@api.doc("get_document_indexing_status") @console_ns.doc("get_document_indexing_status")
@api.doc(description="Get document indexing status") @console_ns.doc(description="Get document indexing status")
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@api.response(200, "Indexing status retrieved successfully") @console_ns.response(200, "Indexing status retrieved successfully")
@api.response(404, "Document not found") @console_ns.response(404, "Document not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -706,17 +711,17 @@ class DocumentIndexingStatusApi(DocumentResource):
class DocumentApi(DocumentResource): class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"} METADATA_CHOICES = {"all", "only", "without"}
@api.doc("get_document") @console_ns.doc("get_document")
@api.doc(description="Get document details") @console_ns.doc(description="Get document details")
@api.doc( @console_ns.doc(
params={ params={
"dataset_id": "Dataset ID", "dataset_id": "Dataset ID",
"document_id": "Document ID", "document_id": "Document ID",
"metadata": "Metadata inclusion (all/only/without)", "metadata": "Metadata inclusion (all/only/without)",
} }
) )
@api.response(200, "Document retrieved successfully") @console_ns.response(200, "Document retrieved successfully")
@api.response(404, "Document not found") @console_ns.response(404, "Document not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -827,14 +832,14 @@ class DocumentApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
class DocumentProcessingApi(DocumentResource): class DocumentProcessingApi(DocumentResource):
@api.doc("update_document_processing") @console_ns.doc("update_document_processing")
@api.doc(description="Update document processing status (pause/resume)") @console_ns.doc(description="Update document processing status (pause/resume)")
@api.doc( @console_ns.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"} params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
) )
@api.response(200, "Processing status updated successfully") @console_ns.response(200, "Processing status updated successfully")
@api.response(404, "Document not found") @console_ns.response(404, "Document not found")
@api.response(400, "Invalid action") @console_ns.response(400, "Invalid action")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -872,11 +877,11 @@ class DocumentProcessingApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
class DocumentMetadataApi(DocumentResource): class DocumentMetadataApi(DocumentResource):
@api.doc("update_document_metadata") @console_ns.doc("update_document_metadata")
@api.doc(description="Update document metadata") @console_ns.doc(description="Update document metadata")
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateDocumentMetadataRequest", "UpdateDocumentMetadataRequest",
{ {
"doc_type": fields.String(description="Document type"), "doc_type": fields.String(description="Document type"),
@ -884,9 +889,9 @@ class DocumentMetadataApi(DocumentResource):
}, },
) )
) )
@api.response(200, "Document metadata updated successfully") @console_ns.response(200, "Document metadata updated successfully")
@api.response(404, "Document not found") @console_ns.response(404, "Document not found")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,9 +3,9 @@ from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
@ -22,16 +22,16 @@ def _validate_name(name: str) -> str:
@console_ns.route("/datasets/external-knowledge-api") @console_ns.route("/datasets/external-knowledge-api")
class ExternalApiTemplateListApi(Resource): class ExternalApiTemplateListApi(Resource):
@api.doc("get_external_api_templates") @console_ns.doc("get_external_api_templates")
@api.doc(description="Get external knowledge API templates") @console_ns.doc(description="Get external knowledge API templates")
@api.doc( @console_ns.doc(
params={ params={
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)", "limit": "Number of items per page (default: 20)",
"keyword": "Search keyword", "keyword": "Search keyword",
} }
) )
@api.response(200, "External API templates retrieved successfully") @console_ns.response(200, "External API templates retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -95,11 +95,11 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>") @console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
class ExternalApiTemplateApi(Resource): class ExternalApiTemplateApi(Resource):
@api.doc("get_external_api_template") @console_ns.doc("get_external_api_template")
@api.doc(description="Get external knowledge API template details") @console_ns.doc(description="Get external knowledge API template details")
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) @console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
@api.response(200, "External API template retrieved successfully") @console_ns.response(200, "External API template retrieved successfully")
@api.response(404, "Template not found") @console_ns.response(404, "Template not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -163,10 +163,10 @@ class ExternalApiTemplateApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check") @console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
class ExternalApiUseCheckApi(Resource): class ExternalApiUseCheckApi(Resource):
@api.doc("check_external_api_usage") @console_ns.doc("check_external_api_usage")
@api.doc(description="Check if external knowledge API is being used") @console_ns.doc(description="Check if external knowledge API is being used")
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) @console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
@api.response(200, "Usage check completed successfully") @console_ns.response(200, "Usage check completed successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -181,10 +181,10 @@ class ExternalApiUseCheckApi(Resource):
@console_ns.route("/datasets/external") @console_ns.route("/datasets/external")
class ExternalDatasetCreateApi(Resource): class ExternalDatasetCreateApi(Resource):
@api.doc("create_external_dataset") @console_ns.doc("create_external_dataset")
@api.doc(description="Create external knowledge dataset") @console_ns.doc(description="Create external knowledge dataset")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CreateExternalDatasetRequest", "CreateExternalDatasetRequest",
{ {
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"), "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
@ -194,18 +194,16 @@ class ExternalDatasetCreateApi(Resource):
}, },
) )
) )
@api.response(201, "External dataset created successfully", dataset_detail_fields) @console_ns.response(201, "External dataset created successfully", dataset_detail_fields)
@api.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@api.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
@ -241,11 +239,11 @@ class ExternalDatasetCreateApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
class ExternalKnowledgeHitTestingApi(Resource): class ExternalKnowledgeHitTestingApi(Resource):
@api.doc("test_external_knowledge_retrieval") @console_ns.doc("test_external_knowledge_retrieval")
@api.doc(description="Test external knowledge retrieval for dataset") @console_ns.doc(description="Test external knowledge retrieval for dataset")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"ExternalHitTestingRequest", "ExternalHitTestingRequest",
{ {
"query": fields.String(required=True, description="Query text for testing"), "query": fields.String(required=True, description="Query text for testing"),
@ -254,9 +252,9 @@ class ExternalKnowledgeHitTestingApi(Resource):
}, },
) )
) )
@api.response(200, "External hit testing completed successfully") @console_ns.response(200, "External hit testing completed successfully")
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@api.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -299,10 +297,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.route("/test/retrieval") @console_ns.route("/test/retrieval")
class BedrockRetrievalApi(Resource): class BedrockRetrievalApi(Resource):
# this api is only for internal testing # this api is only for internal testing
@api.doc("bedrock_retrieval_test") @console_ns.doc("bedrock_retrieval_test")
@api.doc(description="Bedrock retrieval test (internal use only)") @console_ns.doc(description="Bedrock retrieval test (internal use only)")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"BedrockRetrievalTestRequest", "BedrockRetrievalTestRequest",
{ {
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"), "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
@ -311,7 +309,7 @@ class BedrockRetrievalApi(Resource):
}, },
) )
) )
@api.response(200, "Bedrock retrieval test completed") @console_ns.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -12,11 +12,11 @@ from libs.login import login_required
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase): class HitTestingApi(Resource, DatasetsHitTestingBase):
@api.doc("test_dataset_retrieval") @console_ns.doc("test_dataset_retrieval")
@api.doc(description="Test dataset knowledge retrieval") @console_ns.doc(description="Test dataset knowledge retrieval")
@api.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"HitTestingRequest", "HitTestingRequest",
{ {
"query": fields.String(required=True, description="Query text for testing"), "query": fields.String(required=True, description="Query text for testing"),
@ -26,9 +26,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
}, },
) )
) )
@api.response(200, "Hit testing completed successfully") @console_ns.response(200, "Hit testing completed successfully")
@api.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@api.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -130,7 +130,7 @@ parser_datasource = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@api.expect(parser_datasource) @console_ns.expect(parser_datasource)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -176,7 +176,7 @@ parser_datasource_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@api.expect(parser_datasource_delete) @console_ns.expect(parser_datasource_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -209,7 +209,7 @@ parser_datasource_update = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@api.expect(parser_datasource_update) @console_ns.expect(parser_datasource_update)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -267,7 +267,7 @@ parser_datasource_custom = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@api.expect(parser_datasource_custom) @console_ns.expect(parser_datasource_custom)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -306,7 +306,7 @@ parser_default = reqparse.RequestParser().add_argument("id", type=str, required=
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@api.expect(parser_default) @console_ns.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -334,7 +334,7 @@ parser_update_name = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@api.expect(parser_update_name) @console_ns.expect(parser_update_name)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,10 +1,10 @@
from flask_restx import ( # type: ignore from flask_restx import ( # type: ignore
Resource, # type: ignore Resource, # type: ignore
reqparse,
) )
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -12,17 +12,21 @@ from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json") class Parser(BaseModel):
.add_argument("credential_id", type=str, required=False, location="json") inputs: dict
) datasource_type: str
credential_id: str | None = None
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@api.expect(parser) @console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -34,15 +38,10 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
args = parser.parse_args() args = Parser.model_validate(console_ns.payload)
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
inputs = args.inputs
datasource_type = args.datasource_type
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview( preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline, pipeline=pipeline,
@ -51,6 +50,6 @@ class DataSourceContentPreviewApi(Resource):
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=datasource_type,
is_published=True, is_published=True,
credential_id=args.get("credential_id"), credential_id=args.credential_id,
) )
return preview_content, 200 return preview_content, 200

View File

@ -1,11 +1,11 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -21,12 +21,11 @@ class RagPipelineImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
@ -71,12 +70,10 @@ class RagPipelineImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self, import_id): def post(self, import_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -98,12 +95,9 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@login_required @login_required
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields) @marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline) result = import_service.check_dependencies(pipeline=pipeline)
@ -117,11 +111,8 @@ class RagPipelineExportApi(Resource):
@login_required @login_required
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args() args = parser.parse_args()

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
DraftWorkflowNotExist, DraftWorkflowNotExist,
@ -153,7 +153,7 @@ parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@api.expect(parser_run) @console_ns.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -187,10 +187,11 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@api.expect(parser_run) @console_ns.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
@ -198,8 +199,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_run.parse_args() args = parser_run.parse_args()
@ -231,10 +230,11 @@ parser_draft_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@api.expect(parser_draft_run) @console_ns.expect(parser_draft_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -242,8 +242,6 @@ class DraftRagPipelineRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_draft_run.parse_args() args = parser_draft_run.parse_args()
@ -275,10 +273,11 @@ parser_published_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@api.expect(parser_published_run) @console_ns.expect(parser_published_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -286,8 +285,6 @@ class PublishedRagPipelineRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_published_run.parse_args() args = parser_published_run.parse_args()
@ -400,10 +397,11 @@ parser_rag_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run) @console_ns.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
@ -411,8 +409,6 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_rag_run.parse_args() args = parser_rag_run.parse_args()
@ -441,9 +437,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@api.expect(parser_rag_run) @console_ns.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
@ -452,8 +449,6 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_rag_run.parse_args() args = parser_rag_run.parse_args()
@ -487,9 +482,10 @@ parser_run_api = reqparse.RequestParser().add_argument(
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@api.expect(parser_run_api) @console_ns.expect(parser_run_api)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
@ -499,8 +495,6 @@ class RagPipelineDraftNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_run_api.parse_args() args = parser_run_api.parse_args()
@ -523,6 +517,7 @@ class RagPipelineDraftNodeRunApi(Resource):
class RagPipelineTaskStopApi(Resource): class RagPipelineTaskStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, task_id: str): def post(self, pipeline: Pipeline, task_id: str):
@ -531,8 +526,6 @@ class RagPipelineTaskStopApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -544,6 +537,7 @@ class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@ -551,9 +545,6 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline Get published pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published: if not pipeline.is_published:
return None return None
# fetch published workflow by pipeline # fetch published workflow by pipeline
@ -566,6 +557,7 @@ class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -573,9 +565,6 @@ class PublishedRagPipelineApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session: with Session(db.engine) as session:
pipeline = session.merge(pipeline) pipeline = session.merge(pipeline)
@ -602,16 +591,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs # Get default block configs
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
@ -622,20 +607,16 @@ parser_default = reqparse.RequestParser().add_argument("q", type=str, location="
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@api.expect(parser_default) @console_ns.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def get(self, pipeline: Pipeline, block_type: str): def get(self, pipeline: Pipeline, block_type: str):
""" """
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_default.parse_args() args = parser_default.parse_args()
q = args.get("q") q = args.get("q")
@ -663,10 +644,11 @@ parser_wf = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@api.expect(parser_wf) @console_ns.expect(parser_wf)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_pagination_fields) @marshal_with(workflow_pagination_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@ -674,8 +656,6 @@ class PublishedAllRagPipelineApi(Resource):
Get published workflows Get published workflows
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_wf.parse_args() args = parser_wf.parse_args()
page = args["page"] page = args["page"]
@ -716,10 +696,11 @@ parser_wf_id = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@api.expect(parser_wf_id) @console_ns.expect(parser_wf_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def patch(self, pipeline: Pipeline, workflow_id: str): def patch(self, pipeline: Pipeline, workflow_id: str):
@ -728,8 +709,6 @@ class RagPipelineByIdApi(Resource):
""" """
# Check permission # Check permission
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_wf_id.parse_args() args = parser_wf_id.parse_args()
@ -775,7 +754,7 @@ parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, r
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters) @console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -798,7 +777,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters) @console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -821,7 +800,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@api.expect(parser_parameters) @console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -844,7 +823,7 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@api.expect(parser_parameters) @console_ns.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -875,7 +854,7 @@ parser_wf_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@api.expect(parser_wf_run) @console_ns.expect(parser_wf_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -996,7 +975,7 @@ parser_var = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@api.expect(parser_var) @console_ns.expect(parser_var)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
@ -9,10 +9,10 @@ from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusA
@console_ns.route("/website/crawl") @console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@api.doc("crawl_website") @console_ns.doc("crawl_website")
@api.doc(description="Crawl website content") @console_ns.doc(description="Crawl website content")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"WebsiteCrawlRequest", "WebsiteCrawlRequest",
{ {
"provider": fields.String( "provider": fields.String(
@ -25,8 +25,8 @@ class WebsiteCrawlApi(Resource):
}, },
) )
) )
@api.response(200, "Website crawl initiated successfully") @console_ns.response(200, "Website crawl initiated successfully")
@api.response(400, "Invalid crawl parameters") @console_ns.response(400, "Invalid crawl parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -62,12 +62,12 @@ class WebsiteCrawlApi(Resource):
@console_ns.route("/website/crawl/status/<string:job_id>") @console_ns.route("/website/crawl/status/<string:job_id>")
class WebsiteCrawlStatusApi(Resource): class WebsiteCrawlStatusApi(Resource):
@api.doc("get_crawl_status") @console_ns.doc("get_crawl_status")
@api.doc(description="Get website crawl status") @console_ns.doc(description="Get website crawl status")
@api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@api.response(200, "Crawl status retrieved successfully") @console_ns.response(200, "Crawl status retrieved successfully")
@api.response(404, "Crawl job not found") @console_ns.response(404, "Crawl job not found")
@api.response(400, "Invalid provider") @console_ns.response(400, "Invalid provider")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,18 +1,19 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.dataset import Pipeline from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline(
view: Callable | None = None, def get_rag_pipeline(view_func: Callable[P, R]):
):
def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("pipeline_id"): if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters") raise ValueError("missing pipeline_id in path parameters")
@ -37,8 +38,3 @@ def get_rag_pipeline(
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated_view return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -40,7 +40,7 @@ parser_apps = reqparse.RequestParser().add_argument("language", type=str, locati
@console_ns.route("/explore/apps") @console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@api.expect(parser_apps) @console_ns.expect(parser_apps)
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -12,15 +12,17 @@ from services.code_based_extension_service import CodeBasedExtensionService
@console_ns.route("/code-based-extension") @console_ns.route("/code-based-extension")
class CodeBasedExtensionAPI(Resource): class CodeBasedExtensionAPI(Resource):
@api.doc("get_code_based_extension") @console_ns.doc("get_code_based_extension")
@api.doc(description="Get code-based extension data by module name") @console_ns.doc(description="Get code-based extension data by module name")
@api.expect( @console_ns.expect(
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name") console_ns.parser().add_argument(
"module", type=str, required=True, location="args", help="Extension module name"
) )
@api.response( )
@console_ns.response(
200, 200,
"Success", "Success",
api.model( console_ns.model(
"CodeBasedExtensionResponse", "CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
), ),
@ -37,9 +39,9 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension") @console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource): class APIBasedExtensionAPI(Resource):
@api.doc("get_api_based_extensions") @console_ns.doc("get_api_based_extensions")
@api.doc(description="Get all API-based extensions for current tenant") @console_ns.doc(description="Get all API-based extensions for current tenant")
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields))) @console_ns.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -48,10 +50,10 @@ class APIBasedExtensionAPI(Resource):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension") @console_ns.doc("create_api_based_extension")
@api.doc(description="Create a new API-based extension") @console_ns.doc(description="Create a new API-based extension")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"CreateAPIBasedExtensionRequest", "CreateAPIBasedExtensionRequest",
{ {
"name": fields.String(required=True, description="Extension name"), "name": fields.String(required=True, description="Extension name"),
@ -60,13 +62,13 @@ class APIBasedExtensionAPI(Resource):
}, },
) )
) )
@api.response(201, "Extension created successfully", api_based_extension_fields) @console_ns.response(201, "Extension created successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def post(self): def post(self):
args = api.payload args = console_ns.payload
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension( extension_data = APIBasedExtension(
@ -81,10 +83,10 @@ class APIBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension/<uuid:id>") @console_ns.route("/api-based-extension/<uuid:id>")
class APIBasedExtensionDetailAPI(Resource): class APIBasedExtensionDetailAPI(Resource):
@api.doc("get_api_based_extension") @console_ns.doc("get_api_based_extension")
@api.doc(description="Get API-based extension by ID") @console_ns.doc(description="Get API-based extension by ID")
@api.doc(params={"id": "Extension ID"}) @console_ns.doc(params={"id": "Extension ID"})
@api.response(200, "Success", api_based_extension_fields) @console_ns.response(200, "Success", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -95,11 +97,11 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@api.doc("update_api_based_extension") @console_ns.doc("update_api_based_extension")
@api.doc(description="Update API-based extension") @console_ns.doc(description="Update API-based extension")
@api.doc(params={"id": "Extension ID"}) @console_ns.doc(params={"id": "Extension ID"})
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"UpdateAPIBasedExtensionRequest", "UpdateAPIBasedExtensionRequest",
{ {
"name": fields.String(required=True, description="Extension name"), "name": fields.String(required=True, description="Extension name"),
@ -108,7 +110,7 @@ class APIBasedExtensionDetailAPI(Resource):
}, },
) )
) )
@api.response(200, "Extension updated successfully", api_based_extension_fields) @console_ns.response(200, "Extension updated successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -119,7 +121,7 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
args = api.payload args = console_ns.payload
extension_data_from_db.name = args["name"] extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"] extension_data_from_db.api_endpoint = args["api_endpoint"]
@ -129,10 +131,10 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.save(extension_data_from_db) return APIBasedExtensionService.save(extension_data_from_db)
@api.doc("delete_api_based_extension") @console_ns.doc("delete_api_based_extension")
@api.doc(description="Delete API-based extension") @console_ns.doc(description="Delete API-based extension")
@api.doc(params={"id": "Extension ID"}) @console_ns.doc(params={"id": "Extension ID"})
@api.response(204, "Extension deleted successfully") @console_ns.response(204, "Extension deleted successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,18 +3,18 @@ from flask_restx import Resource, fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api, console_ns from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required from .wraps import account_initialization_required, cloud_utm_record, setup_required
@console_ns.route("/features") @console_ns.route("/features")
class FeatureApi(Resource): class FeatureApi(Resource):
@api.doc("get_tenant_features") @console_ns.doc("get_tenant_features")
@api.doc(description="Get feature configuration for current tenant") @console_ns.doc(description="Get feature configuration for current tenant")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
) )
@setup_required @setup_required
@login_required @login_required
@ -29,12 +29,14 @@ class FeatureApi(Resource):
@console_ns.route("/system-features") @console_ns.route("/system-features")
class SystemFeatureApi(Resource): class SystemFeatureApi(Resource):
@api.doc("get_system_features") @console_ns.doc("get_system_features")
@api.doc(description="Get system-wide feature configuration") @console_ns.doc(description="Get system-wide feature configuration")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}), console_ns.model(
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
),
) )
def get(self): def get(self):
"""Get system-wide feature configuration""" """Get system-wide feature configuration"""

View File

@ -11,19 +11,19 @@ from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
from . import api, console_ns from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@console_ns.route("/init") @console_ns.route("/init")
class InitValidateAPI(Resource): class InitValidateAPI(Resource):
@api.doc("get_init_status") @console_ns.doc("get_init_status")
@api.doc(description="Get initialization validation status") @console_ns.doc(description="Get initialization validation status")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
model=api.model( model=console_ns.model(
"InitStatusResponse", "InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
), ),
@ -35,20 +35,20 @@ class InitValidateAPI(Resource):
return {"status": "finished"} return {"status": "finished"}
return {"status": "not_started"} return {"status": "not_started"}
@api.doc("validate_init_password") @console_ns.doc("validate_init_password")
@api.doc(description="Validate initialization password for self-hosted edition") @console_ns.doc(description="Validate initialization password for self-hosted edition")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"InitValidateRequest", "InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)}, {"password": fields.String(required=True, description="Initialization password", max_length=30)},
) )
) )
@api.response( @console_ns.response(
201, 201,
"Success", "Success",
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
) )
@api.response(400, "Already setup or validation failed") @console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Validate initialization password""" """Validate initialization password"""

View File

@ -1,16 +1,16 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from . import api, console_ns from . import console_ns
@console_ns.route("/ping") @console_ns.route("/ping")
class PingApi(Resource): class PingApi(Resource):
@api.doc("health_check") @console_ns.doc("health_check")
@api.doc(description="Health check endpoint for connection testing") @console_ns.doc(description="Health check endpoint for connection testing")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
) )
def get(self): def get(self):
"""Health check endpoint for connection testing""" """Health check endpoint for connection testing"""

View File

@ -10,7 +10,6 @@ from controllers.common.errors import (
RemoteFileUploadError, RemoteFileUploadError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import api
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
@ -42,7 +41,7 @@ parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=
@console_ns.route("/remote-files/upload") @console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@api.expect(parser_upload) @console_ns.expect(parser_upload)
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
args = parser_upload.parse_args() args = parser_upload.parse_args()

View File

@ -7,7 +7,7 @@ from libs.password import valid_password
from models.model import DifySetup, db from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from . import api, console_ns from . import console_ns
from .error import AlreadySetupError, NotInitValidateError from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@ -15,12 +15,12 @@ from .wraps import only_edition_self_hosted
@console_ns.route("/setup") @console_ns.route("/setup")
class SetupApi(Resource): class SetupApi(Resource):
@api.doc("get_setup_status") @console_ns.doc("get_setup_status")
@api.doc(description="Get system setup status") @console_ns.doc(description="Get system setup status")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model( console_ns.model(
"SetupStatusResponse", "SetupStatusResponse",
{ {
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]), "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
@ -40,10 +40,10 @@ class SetupApi(Resource):
return {"step": "not_started"} return {"step": "not_started"}
return {"step": "finished"} return {"step": "finished"}
@api.doc("setup_system") @console_ns.doc("setup_system")
@api.doc(description="Initialize system setup with admin account") @console_ns.doc(description="Initialize system setup with admin account")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"SetupRequest", "SetupRequest",
{ {
"email": fields.String(required=True, description="Admin email address"), "email": fields.String(required=True, description="Admin email address"),
@ -53,8 +53,10 @@ class SetupApi(Resource):
}, },
) )
) )
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")})) @console_ns.response(
@api.response(400, "Already setup or validation failed") 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Initialize system setup with admin account""" """Initialize system setup with admin account"""

View File

@ -2,8 +2,8 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import Tag from models.model import Tag
@ -43,7 +43,7 @@ class TagListApi(Resource):
return tags, 200 return tags, 200
@api.expect(parser_tags) @console_ns.expect(parser_tags)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -68,7 +68,7 @@ parser_tag_id = reqparse.RequestParser().add_argument(
@console_ns.route("/tags/<uuid:tag_id>") @console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource): class TagUpdateDeleteApi(Resource):
@api.expect(parser_tag_id) @console_ns.expect(parser_tag_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -91,12 +91,9 @@ class TagUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, tag_id): def delete(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id) TagService.delete_tag(tag_id)
@ -113,7 +110,7 @@ parser_create = (
@console_ns.route("/tag-bindings/create") @console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource): class TagBindingCreateApi(Resource):
@api.expect(parser_create) @console_ns.expect(parser_create)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -139,7 +136,7 @@ parser_remove = (
@console_ns.route("/tag-bindings/remove") @console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@api.expect(parser_remove) @console_ns.expect(parser_remove)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -7,7 +7,7 @@ from packaging import version
from configs import dify_config from configs import dify_config
from . import api, console_ns from . import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,13 +18,13 @@ parser = reqparse.RequestParser().add_argument(
@console_ns.route("/version") @console_ns.route("/version")
class VersionApi(Resource): class VersionApi(Resource):
@api.doc("check_version_update") @console_ns.doc("check_version_update")
@api.doc(description="Check for application version updates") @console_ns.doc(description="Check for application version updates")
@api.expect(parser) @console_ns.expect(parser)
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model( console_ns.model(
"VersionResponse", "VersionResponse",
{ {
"version": fields.String(description="Latest version number"), "version": fields.String(description="Latest version number"),

View File

@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
@ -55,7 +55,7 @@ def _init_parser():
@console_ns.route("/account/init") @console_ns.route("/account/init")
class AccountInitApi(Resource): class AccountInitApi(Resource):
@api.expect(_init_parser()) @console_ns.expect(_init_parser())
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
@ -115,7 +115,7 @@ parser_name = reqparse.RequestParser().add_argument("name", type=str, required=T
@console_ns.route("/account/name") @console_ns.route("/account/name")
class AccountNameApi(Resource): class AccountNameApi(Resource):
@api.expect(parser_name) @console_ns.expect(parser_name)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,7 +138,7 @@ parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, requir
@console_ns.route("/account/avatar") @console_ns.route("/account/avatar")
class AccountAvatarApi(Resource): class AccountAvatarApi(Resource):
@api.expect(parser_avatar) @console_ns.expect(parser_avatar)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -159,7 +159,7 @@ parser_interface = reqparse.RequestParser().add_argument(
@console_ns.route("/account/interface-language") @console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource): class AccountInterfaceLanguageApi(Resource):
@api.expect(parser_interface) @console_ns.expect(parser_interface)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -180,7 +180,7 @@ parser_theme = reqparse.RequestParser().add_argument(
@console_ns.route("/account/interface-theme") @console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource): class AccountInterfaceThemeApi(Resource):
@api.expect(parser_theme) @console_ns.expect(parser_theme)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -199,7 +199,7 @@ parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, re
@console_ns.route("/account/timezone") @console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource): class AccountTimezoneApi(Resource):
@api.expect(parser_timezone) @console_ns.expect(parser_timezone)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -227,7 +227,7 @@ parser_pw = (
@console_ns.route("/account/password") @console_ns.route("/account/password")
class AccountPasswordApi(Resource): class AccountPasswordApi(Resource):
@api.expect(parser_pw) @console_ns.expect(parser_pw)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -325,7 +325,7 @@ parser_delete = (
@console_ns.route("/account/delete") @console_ns.route("/account/delete")
class AccountDeleteApi(Resource): class AccountDeleteApi(Resource):
@api.expect(parser_delete) @console_ns.expect(parser_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -351,7 +351,7 @@ parser_feedback = (
@console_ns.route("/account/delete/feedback") @console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@api.expect(parser_feedback) @console_ns.expect(parser_feedback)
@setup_required @setup_required
def post(self): def post(self):
args = parser_feedback.parse_args() args = parser_feedback.parse_args()
@ -396,7 +396,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean, "allow_refresh": fields.Boolean,
} }
@api.expect(parser_edu) @console_ns.expect(parser_edu)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -441,7 +441,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean, "has_next": fields.Boolean,
} }
@api.expect(parser_autocomplete) @console_ns.expect(parser_autocomplete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -465,7 +465,7 @@ parser_change_email = (
@console_ns.route("/account/change-email") @console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource): class ChangeEmailSendEmailApi(Resource):
@api.expect(parser_change_email) @console_ns.expect(parser_change_email)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@ -517,7 +517,7 @@ parser_validity = (
@console_ns.route("/account/change-email/validity") @console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource): class ChangeEmailCheckApi(Resource):
@api.expect(parser_validity) @console_ns.expect(parser_validity)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@ -563,7 +563,7 @@ parser_reset = (
@console_ns.route("/account/change-email/reset") @console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource): class ChangeEmailResetApi(Resource):
@api.expect(parser_reset) @console_ns.expect(parser_reset)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@ -603,7 +603,7 @@ parser_check = reqparse.RequestParser().add_argument("email", type=email, requir
@console_ns.route("/account/change-email/check-email-unique") @console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@api.expect(parser_check) @console_ns.expect(parser_check)
@setup_required @setup_required
def post(self): def post(self):
args = parser_check.parse_args() args = parser_check.parse_args()

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -9,9 +9,9 @@ from services.agent_service import AgentService
@console_ns.route("/workspaces/current/agent-providers") @console_ns.route("/workspaces/current/agent-providers")
class AgentProviderListApi(Resource): class AgentProviderListApi(Resource):
@api.doc("list_agent_providers") @console_ns.doc("list_agent_providers")
@api.doc(description="Get list of available agent providers") @console_ns.doc(description="Get list of available agent providers")
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
fields.List(fields.Raw(description="Agent provider information")), fields.List(fields.Raw(description="Agent provider information")),
@ -31,10 +31,10 @@ class AgentProviderListApi(Resource):
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>") @console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
class AgentProviderApi(Resource): class AgentProviderApi(Resource):
@api.doc("get_agent_provider") @console_ns.doc("get_agent_provider")
@api.doc(description="Get specific agent provider details") @console_ns.doc(description="Get specific agent provider details")
@api.doc(params={"provider_name": "Agent provider name"}) @console_ns.doc(params={"provider_name": "Agent provider name"})
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
fields.Raw(description="Agent provider details"), fields.Raw(description="Agent provider details"),

View File

@ -1,8 +1,7 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -11,10 +10,10 @@ from services.plugin.endpoint_service import EndpointService
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@api.doc("create_endpoint") @console_ns.doc("create_endpoint")
@api.doc(description="Create a new plugin endpoint") @console_ns.doc(description="Create a new plugin endpoint")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"EndpointCreateRequest", "EndpointCreateRequest",
{ {
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
@ -23,19 +22,18 @@ class EndpointCreateApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Endpoint created successfully", "Endpoint created successfully",
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
@ -65,17 +63,19 @@ class EndpointCreateApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list") @console_ns.route("/workspaces/current/endpoints/list")
class EndpointListApi(Resource): class EndpointListApi(Resource):
@api.doc("list_endpoints") @console_ns.doc("list_endpoints")
@api.doc(description="List plugin endpoints with pagination") @console_ns.doc(description="List plugin endpoints with pagination")
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number") .add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size") .add_argument("page_size", type=int, required=True, location="args", help="Page size")
) )
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}), console_ns.model(
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
) )
@setup_required @setup_required
@login_required @login_required
@ -107,18 +107,18 @@ class EndpointListApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list/plugin") @console_ns.route("/workspaces/current/endpoints/list/plugin")
class EndpointListForSinglePluginApi(Resource): class EndpointListForSinglePluginApi(Resource):
@api.doc("list_plugin_endpoints") @console_ns.doc("list_plugin_endpoints")
@api.doc(description="List endpoints for a specific plugin") @console_ns.doc(description="List endpoints for a specific plugin")
@api.expect( @console_ns.expect(
api.parser() console_ns.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number") .add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size") .add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
) )
@api.response( @console_ns.response(
200, 200,
"Success", "Success",
api.model( console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
), ),
) )
@ -155,19 +155,22 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.route("/workspaces/current/endpoints/delete") @console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource): class EndpointDeleteApi(Resource):
@api.doc("delete_endpoint") @console_ns.doc("delete_endpoint")
@api.doc(description="Delete a plugin endpoint") @console_ns.doc(description="Delete a plugin endpoint")
@api.expect( @console_ns.expect(
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) console_ns.model(
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
) )
@api.response( )
@console_ns.response(
200, 200,
"Endpoint deleted successfully", "Endpoint deleted successfully",
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -175,9 +178,6 @@ class EndpointDeleteApi(Resource):
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
return { return {
@ -187,10 +187,10 @@ class EndpointDeleteApi(Resource):
@console_ns.route("/workspaces/current/endpoints/update") @console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource): class EndpointUpdateApi(Resource):
@api.doc("update_endpoint") @console_ns.doc("update_endpoint")
@api.doc(description="Update a plugin endpoint") @console_ns.doc(description="Update a plugin endpoint")
@api.expect( @console_ns.expect(
api.model( console_ns.model(
"EndpointUpdateRequest", "EndpointUpdateRequest",
{ {
"endpoint_id": fields.String(required=True, description="Endpoint ID"), "endpoint_id": fields.String(required=True, description="Endpoint ID"),
@ -199,14 +199,15 @@ class EndpointUpdateApi(Resource):
}, },
) )
) )
@api.response( @console_ns.response(
200, 200,
"Endpoint updated successfully", "Endpoint updated successfully",
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -223,9 +224,6 @@ class EndpointUpdateApi(Resource):
settings = args["settings"] settings = args["settings"]
name = args["name"] name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.update_endpoint( "success": EndpointService.update_endpoint(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -239,19 +237,22 @@ class EndpointUpdateApi(Resource):
@console_ns.route("/workspaces/current/endpoints/enable") @console_ns.route("/workspaces/current/endpoints/enable")
class EndpointEnableApi(Resource): class EndpointEnableApi(Resource):
@api.doc("enable_endpoint") @console_ns.doc("enable_endpoint")
@api.doc(description="Enable a plugin endpoint") @console_ns.doc(description="Enable a plugin endpoint")
@api.expect( @console_ns.expect(
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) console_ns.model(
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
) )
@api.response( )
@console_ns.response(
200, 200,
"Endpoint enabled successfully", "Endpoint enabled successfully",
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -261,9 +262,6 @@ class EndpointEnableApi(Resource):
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
} }
@ -271,19 +269,22 @@ class EndpointEnableApi(Resource):
@console_ns.route("/workspaces/current/endpoints/disable") @console_ns.route("/workspaces/current/endpoints/disable")
class EndpointDisableApi(Resource): class EndpointDisableApi(Resource):
@api.doc("disable_endpoint") @console_ns.doc("disable_endpoint")
@api.doc(description="Disable a plugin endpoint") @console_ns.doc(description="Disable a plugin endpoint")
@api.expect( @console_ns.expect(
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) console_ns.model(
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
) )
@api.response( )
@console_ns.response(
200, 200,
"Endpoint disabled successfully", "Endpoint disabled successfully",
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@api.response(403, "Admin privileges required") @console_ns.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -293,9 +294,6 @@ class EndpointDisableApi(Resource):
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
} }

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
CannotTransferOwnerToSelfError, CannotTransferOwnerToSelfError,
EmailCodeError, EmailCodeError,
@ -60,7 +60,7 @@ parser_invite = (
class MemberInviteEmailApi(Resource): class MemberInviteEmailApi(Resource):
"""Invite a new member by email.""" """Invite a new member by email."""
@api.expect(parser_invite) @console_ns.expect(parser_invite)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -153,7 +153,7 @@ parser_update = reqparse.RequestParser().add_argument("role", type=str, required
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
"""Update member role.""" """Update member role."""
@api.expect(parser_update) @console_ns.expect(parser_update)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -204,7 +204,7 @@ parser_send = reqparse.RequestParser().add_argument("language", type=str, requir
class SendOwnerTransferEmailApi(Resource): class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email.""" """Send owner transfer email."""
@api.expect(parser_send) @console_ns.expect(parser_send)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -247,7 +247,7 @@ parser_owner = (
@console_ns.route("/workspaces/current/members/owner-transfer-check") @console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource): class OwnerTransferCheckApi(Resource):
@api.expect(parser_owner) @console_ns.expect(parser_owner)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -295,7 +295,7 @@ parser_owner_transfer = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer") @console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource): class OwnerTransfer(Resource):
@api.expect(parser_owner_transfer) @console_ns.expect(parser_owner_transfer)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -2,10 +2,9 @@ import io
from flask import send_file from flask import send_file
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -26,7 +25,7 @@ parser_model = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers") @console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource): class ModelProviderListApi(Resource):
@api.expect(parser_model) @console_ns.expect(parser_model)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -65,7 +64,7 @@ parser_delete_cred = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource): class ModelProviderCredentialApi(Resource):
@api.expect(parser_cred) @console_ns.expect(parser_cred)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -82,15 +81,13 @@ class ModelProviderCredentialApi(Resource):
return {"credentials": credentials} return {"credentials": credentials}
@api.expect(parser_post_cred) @console_ns.expect(parser_post_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_cred.parse_args() args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -107,14 +104,13 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"}, 201 return {"result": "success"}, 201
@api.expect(parser_put_cred) @console_ns.expect(parser_put_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_put_cred.parse_args() args = parser_put_cred.parse_args()
@ -133,15 +129,13 @@ class ModelProviderCredentialApi(Resource):
return {"result": "success"} return {"result": "success"}
@api.expect(parser_delete_cred) @console_ns.expect(parser_delete_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_delete_cred.parse_args() args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -159,14 +153,13 @@ parser_switch = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource): class ModelProviderCredentialSwitchApi(Resource):
@api.expect(parser_switch) @console_ns.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_switch.parse_args() args = parser_switch.parse_args()
service = ModelProviderService() service = ModelProviderService()
@ -185,7 +178,7 @@ parser_validate = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@api.expect(parser_validate) @console_ns.expect(parser_validate)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -247,14 +240,13 @@ parser_preferred = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type") @console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource): class PreferredProviderTypeUpdateApi(Resource):
@api.expect(parser_preferred) @console_ns.expect(parser_preferred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_tenant_id tenant_id = current_tenant_id

View File

@ -1,10 +1,9 @@
import logging import logging
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -31,7 +30,7 @@ parser_post_default = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/default-model") @console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource): class DefaultModelApi(Resource):
@api.expect(parser_get_default) @console_ns.expect(parser_get_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -47,15 +46,13 @@ class DefaultModelApi(Resource):
return jsonable_encoder({"data": default_model_entity}) return jsonable_encoder({"data": default_model_entity})
@api.expect(parser_post_default) @console_ns.expect(parser_post_default)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_default.parse_args() args = parser_post_default.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -130,16 +127,14 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models}) return jsonable_encoder({"data": models})
@api.expect(parser_post_models) @console_ns.expect(parser_post_models)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
# To save the model's load balance configs # To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_models.parse_args() args = parser_post_models.parse_args()
if args.get("config_from", "") == "custom-model": if args.get("config_from", "") == "custom-model":
@ -178,15 +173,13 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
@api.expect(parser_delete_models) @console_ns.expect(parser_delete_models)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_delete_models.parse_args() args = parser_delete_models.parse_args()
@ -260,7 +253,7 @@ parser_delete_cred = (
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource): class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_get_credentials) @console_ns.expect(parser_get_credentials)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -311,15 +304,13 @@ class ModelProviderModelCredentialApi(Resource):
} }
) )
@api.expect(parser_post_cred) @console_ns.expect(parser_post_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_cred.parse_args() args = parser_post_cred.parse_args()
@ -345,16 +336,13 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"}, 201 return {"result": "success"}, 201
@api.expect(parser_put_cred) @console_ns.expect(parser_put_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_put_cred.parse_args() args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -374,15 +362,13 @@ class ModelProviderModelCredentialApi(Resource):
return {"result": "success"} return {"result": "success"}
@api.expect(parser_delete_cred) @console_ns.expect(parser_delete_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_delete_cred.parse_args() args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -414,15 +400,14 @@ parser_switch = (
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource): class ModelProviderModelCredentialSwitchApi(Resource):
@api.expect(parser_switch) @console_ns.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_switch.parse_args() args = parser_switch.parse_args()
service = ModelProviderService() service = ModelProviderService()
@ -454,7 +439,7 @@ parser_model_enable_disable = (
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable" "/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
) )
class ModelProviderModelEnableApi(Resource): class ModelProviderModelEnableApi(Resource):
@api.expect(parser_model_enable_disable) @console_ns.expect(parser_model_enable_disable)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -475,7 +460,7 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable" "/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
) )
class ModelProviderModelDisableApi(Resource): class ModelProviderModelDisableApi(Resource):
@api.expect(parser_model_enable_disable) @console_ns.expect(parser_model_enable_disable)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -509,7 +494,7 @@ parser_validate = (
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource): class ModelProviderModelValidateApi(Resource):
@api.expect(parser_validate) @console_ns.expect(parser_validate)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -550,7 +535,7 @@ parser_parameter = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource): class ModelProviderModelParameterRuleApi(Resource):
@api.expect(parser_parameter) @console_ns.expect(parser_parameter)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -5,9 +5,9 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -46,7 +46,7 @@ parser_list = (
@console_ns.route("/workspaces/current/plugin/list") @console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource): class PluginListApi(Resource):
@api.expect(parser_list) @console_ns.expect(parser_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -66,7 +66,7 @@ parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, r
@console_ns.route("/workspaces/current/plugin/list/latest-versions") @console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource): class PluginListLatestVersionsApi(Resource):
@api.expect(parser_latest) @console_ns.expect(parser_latest)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -86,7 +86,7 @@ parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, requ
@console_ns.route("/workspaces/current/plugin/list/installations/ids") @console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@api.expect(parser_ids) @console_ns.expect(parser_ids)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -112,7 +112,7 @@ parser_icon = (
@console_ns.route("/workspaces/current/plugin/icon") @console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource): class PluginIconApi(Resource):
@api.expect(parser_icon) @console_ns.expect(parser_icon)
@setup_required @setup_required
def get(self): def get(self):
args = parser_icon.parse_args() args = parser_icon.parse_args()
@ -132,9 +132,11 @@ class PluginAssetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
req = reqparse.RequestParser() req = (
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args") reqparse.RequestParser()
req.add_argument("file_name", type=str, required=True, location="args") .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("file_name", type=str, required=True, location="args")
)
args = req.parse_args() args = req.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
@ -179,7 +181,7 @@ parser_github = (
@console_ns.route("/workspaces/current/plugin/upload/github") @console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource): class PluginUploadFromGithubApi(Resource):
@api.expect(parser_github) @console_ns.expect(parser_github)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -228,7 +230,7 @@ parser_pkg = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/plugin/install/pkg") @console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource): class PluginInstallFromPkgApi(Resource):
@api.expect(parser_pkg) @console_ns.expect(parser_pkg)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -261,7 +263,7 @@ parser_githubapi = (
@console_ns.route("/workspaces/current/plugin/install/github") @console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource): class PluginInstallFromGithubApi(Resource):
@api.expect(parser_githubapi) @console_ns.expect(parser_githubapi)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -292,7 +294,7 @@ parser_marketplace = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/plugin/install/marketplace") @console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource): class PluginInstallFromMarketplaceApi(Resource):
@api.expect(parser_marketplace) @console_ns.expect(parser_marketplace)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -322,7 +324,7 @@ parser_pkgapi = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/plugin/marketplace/pkg") @console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource): class PluginFetchMarketplacePkgApi(Resource):
@api.expect(parser_pkgapi) @console_ns.expect(parser_pkgapi)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -351,7 +353,7 @@ parser_fetch = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/plugin/fetch-manifest") @console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@api.expect(parser_fetch) @console_ns.expect(parser_fetch)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -382,7 +384,7 @@ parser_tasks = (
@console_ns.route("/workspaces/current/plugin/tasks") @console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource): class PluginFetchInstallTasksApi(Resource):
@api.expect(parser_tasks) @console_ns.expect(parser_tasks)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -469,7 +471,7 @@ parser_marketplace_api = (
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace") @console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource): class PluginUpgradeFromMarketplaceApi(Resource):
@api.expect(parser_marketplace_api) @console_ns.expect(parser_marketplace_api)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -501,7 +503,7 @@ parser_github_post = (
@console_ns.route("/workspaces/current/plugin/upgrade/github") @console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource): class PluginUpgradeFromGithubApi(Resource):
@api.expect(parser_github_post) @console_ns.expect(parser_github_post)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -533,7 +535,7 @@ parser_uninstall = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/plugin/uninstall") @console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource): class PluginUninstallApi(Resource):
@api.expect(parser_uninstall) @console_ns.expect(parser_uninstall)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -558,7 +560,7 @@ parser_change_post = (
@console_ns.route("/workspaces/current/plugin/permission/change") @console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource): class PluginChangePermissionApi(Resource):
@api.expect(parser_change_post) @console_ns.expect(parser_change_post)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -616,16 +618,13 @@ parser_dynamic = (
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options") @console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource): class PluginFetchDynamicSelectOptionsApi(Resource):
@api.expect(parser_dynamic) @console_ns.expect(parser_dynamic)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id user_id = current_user.id
args = parser_dynamic.parse_args() args = parser_dynamic.parse_args()
@ -656,7 +655,7 @@ parser_change = (
@console_ns.route("/workspaces/current/plugin/preferences/change") @console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource): class PluginChangePreferencesApi(Resource):
@api.expect(parser_change) @console_ns.expect(parser_change)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -750,7 +749,7 @@ parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, re
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude") @console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource): class PluginAutoUpgradeExcludePluginApi(Resource):
@api.expect(parser_exclude) @console_ns.expect(parser_exclude)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -770,9 +769,11 @@ class PluginReadmeApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="args") .add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("language", type=str, required=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return jsonable_encoder( return jsonable_encoder(
{ {

View File

@ -10,10 +10,11 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@ -64,7 +65,7 @@ parser_tool = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-providers") @console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource): class ToolProviderListApi(Resource):
@api.expect(parser_tool) @console_ns.expect(parser_tool)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -112,14 +113,13 @@ parser_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@api.expect(parser_delete) @console_ns.expect(parser_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_delete.parse_args() args = parser_delete.parse_args()
@ -140,7 +140,7 @@ parser_add = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource): class ToolBuiltinProviderAddApi(Resource):
@api.expect(parser_add) @console_ns.expect(parser_add)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -174,16 +174,13 @@ parser_update = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource): class ToolBuiltinProviderUpdateApi(Resource):
@api.expect(parser_update) @console_ns.expect(parser_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_update.parse_args() args = parser_update.parse_args()
@ -239,16 +236,14 @@ parser_api_add = (
@console_ns.route("/workspaces/current/tool-provider/api/add") @console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@api.expect(parser_api_add) @console_ns.expect(parser_api_add)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_add.parse_args() args = parser_api_add.parse_args()
@ -272,7 +267,7 @@ parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=
@console_ns.route("/workspaces/current/tool-provider/api/remote") @console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource): class ToolApiProviderGetRemoteSchemaApi(Resource):
@api.expect(parser_remote) @console_ns.expect(parser_remote)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -297,7 +292,7 @@ parser_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/tools") @console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource): class ToolApiProviderListToolsApi(Resource):
@api.expect(parser_tools) @console_ns.expect(parser_tools)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -333,16 +328,14 @@ parser_api_update = (
@console_ns.route("/workspaces/current/tool-provider/api/update") @console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource): class ToolApiProviderUpdateApi(Resource):
@api.expect(parser_api_update) @console_ns.expect(parser_api_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_update.parse_args() args = parser_api_update.parse_args()
@ -369,16 +362,14 @@ parser_api_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/delete") @console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource): class ToolApiProviderDeleteApi(Resource):
@api.expect(parser_api_delete) @console_ns.expect(parser_api_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_delete.parse_args() args = parser_api_delete.parse_args()
@ -395,7 +386,7 @@ parser_get = reqparse.RequestParser().add_argument("provider", type=str, require
@console_ns.route("/workspaces/current/tool-provider/api/get") @console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource): class ToolApiProviderGetApi(Resource):
@api.expect(parser_get) @console_ns.expect(parser_get)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -435,7 +426,7 @@ parser_schema = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/schema") @console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@api.expect(parser_schema) @console_ns.expect(parser_schema)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -460,7 +451,7 @@ parser_pre = (
@console_ns.route("/workspaces/current/tool-provider/api/test/pre") @console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource): class ToolApiProviderPreviousTestApi(Resource):
@api.expect(parser_pre) @console_ns.expect(parser_pre)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -493,16 +484,14 @@ parser_create = (
@console_ns.route("/workspaces/current/tool-provider/workflow/create") @console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource): class ToolWorkflowProviderCreateApi(Resource):
@api.expect(parser_create) @console_ns.expect(parser_create)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_create.parse_args() args = parser_create.parse_args()
@ -536,16 +525,13 @@ parser_workflow_update = (
@console_ns.route("/workspaces/current/tool-provider/workflow/update") @console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource): class ToolWorkflowProviderUpdateApi(Resource):
@api.expect(parser_workflow_update) @console_ns.expect(parser_workflow_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_workflow_update.parse_args() args = parser_workflow_update.parse_args()
@ -574,16 +560,14 @@ parser_workflow_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/delete") @console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource): class ToolWorkflowProviderDeleteApi(Resource):
@api.expect(parser_workflow_delete) @console_ns.expect(parser_workflow_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_workflow_delete.parse_args() args = parser_workflow_delete.parse_args()
@ -604,7 +588,7 @@ parser_wf_get = (
@console_ns.route("/workspaces/current/tool-provider/workflow/get") @console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource): class ToolWorkflowProviderGetApi(Resource):
@api.expect(parser_wf_get) @console_ns.expect(parser_wf_get)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -640,7 +624,7 @@ parser_wf_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/tools") @console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource): class ToolWorkflowProviderListToolApi(Resource):
@api.expect(parser_wf_tools) @console_ns.expect(parser_wf_tools)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -734,18 +718,15 @@ class ToolLabelsApi(Resource):
class ToolPluginOAuthApi(Resource): class ToolPluginOAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name provider_name = tool_provider.provider_name
# todo check permission
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None: if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
@ -832,7 +813,7 @@ parser_default_cred = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource): class ToolBuiltinProviderSetDefaultApi(Resource):
@api.expect(parser_default_cred) @console_ns.expect(parser_default_cred)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -853,17 +834,15 @@ parser_custom = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource): class ToolOAuthCustomClient(Resource):
@api.expect(parser_custom) @console_ns.expect(parser_custom)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider: str):
args = parser_custom.parse_args() args = parser_custom.parse_args()
user, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params( return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -953,7 +932,7 @@ parser_mcp_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/mcp") @console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@api.expect(parser_mcp) @console_ns.expect(parser_mcp)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -983,7 +962,7 @@ class ToolProviderMCPApi(Resource):
) )
return jsonable_encoder(result) return jsonable_encoder(result)
@api.expect(parser_mcp_put) @console_ns.expect(parser_mcp_put)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1022,7 +1001,7 @@ class ToolProviderMCPApi(Resource):
) )
return {"result": "success"} return {"result": "success"}
@api.expect(parser_mcp_delete) @console_ns.expect(parser_mcp_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1045,7 +1024,7 @@ parser_auth = (
@console_ns.route("/workspaces/current/tool-provider/mcp/auth") @console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource): class ToolMCPAuthApi(Resource):
@api.expect(parser_auth) @console_ns.expect(parser_auth)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1086,7 +1065,13 @@ class ToolMCPAuthApi(Resource):
return {"result": "success"} return {"result": "success"}
except MCPAuthError as e: except MCPAuthError as e:
try: try:
auth_result = auth(provider_entity, args.get("authorization_code")) # Pass the extracted OAuth metadata hints to auth()
auth_result = auth(
provider_entity,
args.get("authorization_code"),
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result) response = service.execute_auth_actions(auth_result)
@ -1096,7 +1081,7 @@ class ToolMCPAuthApi(Resource):
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e: except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
@ -1157,7 +1142,7 @@ parser_cb = (
@console_ns.route("/mcp/oauth/callback") @console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
@api.expect(parser_cb) @console_ns.expect(parser_cb)
def get(self): def get(self):
args = parser_cb.parse_args() args = parser_cb.parse_args()
state_key = args["state"] state_key = args["state"]

View File

@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.web.error import NotFoundError from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
@ -67,14 +67,12 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource): class TriggerSubscriptionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider""" """List all trigger subscriptions for the current tenant's provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
return jsonable_encoder( return jsonable_encoder(
@ -92,17 +90,16 @@ class TriggerSubscriptionListApi(Resource):
class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
"""Add a new subscription instance for a trigger provider""" """Add a new subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json") "credential_type", type=str, required=False, nullable=True, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -133,18 +130,17 @@ class TriggerSubscriptionBuilderGetApi(Resource):
class TriggerSubscriptionBuilderVerifyApi(Resource): class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider""" """Verify a subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser()
# The credentials of the subscription builder # The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -173,15 +169,17 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account) assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
# The name of the subscription builder # The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json") .add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder # The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder # The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") .add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder # The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
@ -223,24 +221,23 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
class TriggerSubscriptionBuilderBuildApi(Resource): class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider""" """Build a subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser()
# The name of the subscription builder # The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json") .add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder # The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") .add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder # The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") .add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder # The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
# Use atomic update_and_build to prevent race conditions # Use atomic update_and_build to prevent race conditions
@ -264,14 +261,12 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
class TriggerSubscriptionDeleteApi(Resource): class TriggerSubscriptionDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, subscription_id: str): def post(self, subscription_id: str):
"""Delete a subscription instance""" """Delete a subscription instance"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
@ -446,14 +441,12 @@ class TriggerOAuthCallbackApi(Resource):
class TriggerOAuthClientManageApi(Resource): class TriggerOAuthClientManageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""Get OAuth client configuration for a provider""" """Get OAuth client configuration for a provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
@ -493,18 +486,18 @@ class TriggerOAuthClientManageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
"""Configure custom OAuth client for a provider""" """Configure custom OAuth client for a provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json") .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -524,14 +517,12 @@ class TriggerOAuthClientManageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider): def delete(self, provider):
"""Remove custom OAuth client configuration""" """Remove custom OAuth client configuration"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
@ -548,45 +539,49 @@ class TriggerOAuthClientManageApi(Resource):
# Trigger Subscription # Trigger Subscription
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon") console_ns.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers") console_ns.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info") console_ns.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list") console_ns.add_resource(
api.add_resource( TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list"
)
console_ns.add_resource(
TriggerSubscriptionDeleteApi, TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete", "/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
) )
# Trigger Subscription Builder # Trigger Subscription Builder
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderCreateApi, TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
) )
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderGetApi, TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
) )
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderUpdateApi, TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
) )
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderVerifyApi, TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
) )
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderBuildApi, TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
) )
api.add_resource( console_ns.add_resource(
TriggerSubscriptionBuilderLogsApi, TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>", "/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
) )
# OAuth # OAuth
api.add_resource( console_ns.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize" TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
) )
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback") console_ns.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client") console_ns.add_resource(
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
)

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import api, console_ns from controllers.console import console_ns
from controllers.console.admin import admin_required from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -128,7 +128,7 @@ class TenantApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(tenant_fields) @marshal_with(tenant_fields)
def get(self): def post(self):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
@ -155,7 +155,7 @@ parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, req
@console_ns.route("/workspaces/switch") @console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource): class SwitchWorkspaceApi(Resource):
@api.expect(parser_switch) @console_ns.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -250,7 +250,7 @@ parser_info = reqparse.RequestParser().add_argument("name", type=str, required=T
@console_ns.route("/workspaces/info") @console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource): class WorkspaceInfoApi(Resource):
@api.expect(parser_info) @console_ns.expect(parser_info)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -315,3 +315,19 @@ def edit_permission_required(f: Callable[P, R]):
return f(*args, **kwargs) return f(*args, **kwargs)
return decorated_function return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -3,14 +3,12 @@ from typing import Literal
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models import Account
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@ -161,14 +159,10 @@ class AnnotationUpdateDeleteApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns)) @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id): def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation.""" """Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
args = annotation_create_parser.parse_args() args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation return annotation
@ -185,13 +179,8 @@ class AnnotationUpdateDeleteApi(Resource):
} }
) )
@validate_app_token @validate_app_token
def delete(self, app_model: App, annotation_id): @edit_permission_required
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation.""" """Delete an annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -5,6 +5,7 @@ from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@ -619,11 +620,9 @@ class DatasetTagsApi(DatasetApiResource):
} }
) )
@validate_dataset_token @validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"]) TagService.delete_tag(args["tag_id"])

View File

@ -1,7 +1,10 @@
import json import json
from typing import Self
from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from sqlalchemy import desc, select from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -31,7 +34,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService from services.file_service import FileService
# Define parsers for document operations # Define parsers for document operations
@ -51,15 +54,26 @@ document_text_create_parser = (
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
) )
document_text_update_parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("text", type=str, required=False, nullable=True, location="json") class DocumentTextUpdate(BaseModel):
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") name: str | None = None
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") text: str | None = None
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") process_rule: ProcessRule | None = None
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") doc_form: str = "text_model"
) doc_language: str = "English"
retrieval_model: RetrievalModel | None = None
@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:
raise ValueError("name is required when text is provided")
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@service_api_ns.route( @service_api_ns.route(
@ -160,7 +174,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@service_api_ns.expect(document_text_update_parser) @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.doc("update_document_by_text") @service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -173,12 +187,10 @@ class DocumentUpdateByTextApi(DatasetApiResource):
) )
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text.""" """Update document by text."""
args = document_text_update_parser.parse_args() args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
dataset_id = str(dataset_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -198,11 +210,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
if args["text"]: if args.get("text"):
text = args.get("text") text = args.get("text")
name = args.get("name") name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text( upload_file = FileService(db.engine).upload_text(
@ -456,12 +466,16 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
status = request.args.get("status", default=None, type=str)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.where(Document.name.like(search))

View File

@ -81,6 +81,7 @@ class LoginStatusApi(Resource):
) )
def get(self): def get(self):
app_code = request.args.get("app_code") app_code = request.args.get("app_code")
user_id = request.args.get("user_id")
token = extract_webapp_access_token(request) token = extract_webapp_access_token(request)
if not app_code: if not app_code:
return { return {
@ -103,7 +104,7 @@ class LoginStatusApi(Resource):
user_logged_in = False user_logged_in = False
try: try:
_ = decode_jwt_token(app_code=app_code) _ = decode_jwt_token(app_code=app_code, user_id=user_id)
app_logged_in = True app_logged_in = True
except Exception: except Exception:
app_logged_in = False app_logged_in = False

View File

@ -38,7 +38,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator return decorator
def decode_jwt_token(app_code: str | None = None): def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
system_features = FeatureService.get_system_features() system_features = FeatureService.get_system_features()
if not app_code: if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE)) app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
@ -63,6 +63,10 @@ def decode_jwt_token(app_code: str | None = None):
if not end_user: if not end_user:
raise NotFound() raise NotFound()
# Validate user_id against end_user's session_id if provided
if user_id is not None and end_user.session_id != user_id:
raise Unauthorized("Authentication has expired.")
# for enterprise webapp auth # for enterprise webapp auth
app_web_auth_enabled = False app_web_auth_enabled = False
webapp_settings = None webapp_settings = None

View File

@ -112,6 +112,7 @@ class VariableEntity(BaseModel):
type: VariableEntityType type: VariableEntityType
required: bool = False required: bool = False
hide: bool = False hide: bool = False
default: Any = None
max_length: int | None = None max_length: int | None = None
options: Sequence[str] = Field(default_factory=list) options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)

View File

@ -93,7 +93,11 @@ class BaseAppGenerator:
if value is None: if value is None:
if variable_entity.required: if variable_entity.required:
raise ValueError(f"{variable_entity.variable} is required in input form") raise ValueError(f"{variable_entity.variable} is required in input form")
return value # Use default value and continue validation to ensure type conversion
value = variable_entity.default
# If default is also None, return None directly
if value is None:
return None
if variable_entity.type in { if variable_entity.type in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,

View File

@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type=datasource_type, datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info), datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id, datasource_node_id=start_node_id,
input_data=inputs, input_data=dict(inputs),
pipeline_id=pipeline.id, pipeline_id=pipeline.id,
created_by=user.id, created_by=user.id,
) )

View File

@ -145,7 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args), **extract_external_trace_id_from_args(args),
} }
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
# for trigger debug run, not prepare user inputs # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args): if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs( inputs = self._prepare_user_inputs(
user_inputs=inputs, user_inputs=inputs,

View File

@ -644,14 +644,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if not workflow_run_id: if not workflow_run_id:
return return
workflow_app_log = WorkflowAppLog() workflow_app_log = WorkflowAppLog(
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id=self._application_generate_entity.app_config.tenant_id,
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id app_id=self._application_generate_entity.app_config.app_id,
workflow_app_log.workflow_id = self._workflow.id workflow_id=self._workflow.id,
workflow_app_log.workflow_run_id = workflow_run_id workflow_run_id=workflow_run_id,
workflow_app_log.created_from = created_from.value created_from=created_from.value,
workflow_app_log.created_by_role = self._created_by_role created_by_role=self._created_by_role,
workflow_app_log.created_by = self._user_id created_by=self._user_id,
)
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit() session.commit()

View File

@ -1,14 +1,10 @@
from typing import TYPE_CHECKING, Any, Optional from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
if TYPE_CHECKING:
from core.app.entities.app_invoke_entities import InvokeFrom
class DatasourceRuntime(BaseModel): class DatasourceRuntime(BaseModel):
""" """
@ -17,7 +13,7 @@ class DatasourceRuntime(BaseModel):
tenant_id: str tenant_id: str
datasource_id: str | None = None datasource_id: str | None = None
invoke_from: Optional["InvokeFrom"] = None invoke_from: InvokeFrom | None = None
datasource_invoke_from: DatasourceInvokeFrom | None = None datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict) credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@ -6,7 +6,8 @@ import secrets
import urllib.parse import urllib.parse
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from httpx import ConnectError, HTTPStatusError, RequestError import httpx
from httpx import RequestError
from pydantic import ValidationError from pydantic import ValidationError
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
@ -20,6 +21,7 @@ from core.mcp.types import (
OAuthClientMetadata, OAuthClientMetadata,
OAuthMetadata, OAuthMetadata,
OAuthTokens, OAuthTokens,
ProtectedResourceMetadata,
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge return code_verifier, code_challenge
def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []
# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
return urls
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
return urls
def discover_protected_resource_metadata(
prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def discover_oauth_authorization_server_metadata(
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL
return None
def get_effective_scope(
scope_from_www_auth: str | None,
prm: ProtectedResourceMetadata | None,
asm: OAuthMetadata | None,
client_scope: str | None,
) -> str | None:
"""
Determine effective scope using priority-based selection strategy.
Priority order:
1. WWW-Authenticate header scope (server explicit requirement)
2. Protected Resource Metadata scopes
3. OAuth Authorization Server Metadata scopes
4. Client configured scope
"""
if scope_from_www_auth:
return scope_from_www_auth
if prm and prm.scopes_supported:
return " ".join(prm.scopes_supported)
if asm and asm.scopes_supported:
return " ".join(asm.scopes_supported)
return client_scope
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str: def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key.""" """Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key # Generate a secure random state key
@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
return False, "" return False, ""
def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None: def discover_oauth_metadata(
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" server_url: str,
# First check if the server supports OAuth 2.0 Resource Discovery resource_metadata_url: str | None = None,
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) scope_hint: str | None = None,
if support_resource_discovery: protocol_version: str | None = None,
# The oauth_discovery_url is the authorization server base URL ) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
# Try OpenID Connect discovery first (more common), then OAuth 2.0 """
urls_to_try = [ Discover OAuth metadata using RFC 8414/9470 standards.
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} Args:
server_url: The MCP server URL
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
scope_hint: Scope hint from WWW-Authenticate header
protocol_version: MCP protocol version
for url in urls_to_try: Returns:
try: (oauth_metadata, protected_resource_metadata, scope_hint)
response = ssrf_proxy.get(url, headers=headers) """
if response.status_code == 404: # Discover Protected Resource Metadata
continue prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
if not response.is_success:
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
return None # No metadata found # Get authorization server URL from PRM or use server URL
auth_server_url = None
if prm and prm.authorization_servers:
auth_server_url = prm.authorization_servers[0]
# Discover OAuth Authorization Server Metadata
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
return asm, prm, scope_hint
def start_authorization( def start_authorization(
@ -166,6 +287,7 @@ def start_authorization(
redirect_url: str, redirect_url: str,
provider_id: str, provider_id: str,
tenant_id: str, tenant_id: str,
scope: str | None = None,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage.""" """Begins the authorization flow with secure Redis state storage."""
response_type = "code" response_type = "code"
@ -175,13 +297,6 @@ def start_authorization(
authorization_url = metadata.authorization_endpoint authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported: if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}") raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else: else:
authorization_url = urljoin(server_url, "/authorize") authorization_url = urljoin(server_url, "/authorize")
@ -210,10 +325,49 @@ def start_authorization(
"state": state_key, "state": state_key,
} }
# Add scope if provided
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier return authorization_url, code_verifier
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
"""
Parse OAuth token response supporting both JSON and form-urlencoded formats.
Per RFC 6749 Section 5.1, the standard format is JSON.
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
application/x-www-form-urlencoded format for backwards compatibility.
Args:
response: The HTTP response from token endpoint
Returns:
Parsed OAuth tokens
Raises:
ValueError: If response cannot be parsed
"""
content_type = response.headers.get("content-type", "").lower()
if "application/json" in content_type:
# Standard OAuth 2.0 JSON response (RFC 6749)
return OAuthTokens.model_validate(response.json())
elif "application/x-www-form-urlencoded" in content_type:
# Legacy form-urlencoded response (non-standard but used by some providers)
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
else:
# No content-type or unknown - try JSON first, fallback to form-urlencoded
try:
return OAuthTokens.model_validate(response.json())
except (ValidationError, json.JSONDecodeError):
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
def exchange_authorization( def exchange_authorization(
server_url: str, server_url: str,
metadata: OAuthMetadata | None, metadata: OAuthMetadata | None,
@ -246,7 +400,7 @@ def exchange_authorization(
response = ssrf_proxy.post(token_url, data=params) response = ssrf_proxy.post(token_url, data=params)
if not response.is_success: if not response.is_success:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}") raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json()) return _parse_token_response(response)
def refresh_authorization( def refresh_authorization(
@ -279,7 +433,7 @@ def refresh_authorization(
raise MCPRefreshTokenError(e) from e raise MCPRefreshTokenError(e) from e
if not response.is_success: if not response.is_success:
raise MCPRefreshTokenError(response.text) raise MCPRefreshTokenError(response.text)
return OAuthTokens.model_validate(response.json()) return _parse_token_response(response)
def client_credentials_flow( def client_credentials_flow(
@ -322,7 +476,7 @@ def client_credentials_flow(
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}" f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
) )
return OAuthTokens.model_validate(response.json()) return _parse_token_response(response)
def register_client( def register_client(
@ -352,6 +506,8 @@ def auth(
provider: MCPProviderEntity, provider: MCPProviderEntity,
authorization_code: str | None = None, authorization_code: str | None = None,
state_param: str | None = None, state_param: str | None = None,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
) -> AuthResult: ) -> AuthResult:
""" """
Orchestrates the full auth flow with a server using secure Redis state storage. Orchestrates the full auth flow with a server using secure Redis state storage.
@ -363,18 +519,26 @@ def auth(
provider: The MCP provider entity provider: The MCP provider entity
authorization_code: Optional authorization code from OAuth callback authorization_code: Optional authorization code from OAuth callback
state_param: Optional state parameter from OAuth callback state_param: Optional state parameter from OAuth callback
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
scope_hint: Optional scope hint from WWW-Authenticate header
Returns: Returns:
AuthResult containing actions to be performed and response data AuthResult containing actions to be performed and response data
""" """
actions: list[AuthAction] = [] actions: list[AuthAction] = []
server_url = provider.decrypt_server_url() server_url = provider.decrypt_server_url()
server_metadata = discover_oauth_metadata(server_url)
# Discover OAuth metadata using RFC 8414/9470 standards
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
)
client_metadata = provider.client_metadata client_metadata = provider.client_metadata
provider_id = provider.id provider_id = provider.id
tenant_id = provider.tenant_id tenant_id = provider.tenant_id
client_information = provider.retrieve_client_information() client_information = provider.retrieve_client_information()
redirect_url = provider.redirect_url redirect_url = provider.redirect_url
credentials = provider.decrypt_credentials()
# Determine grant type based on server metadata # Determine grant type based on server metadata
if not server_metadata: if not server_metadata:
@ -392,8 +556,8 @@ def auth(
else: else:
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
# Get stored credentials # Determine effective scope using priority-based strategy
credentials = provider.decrypt_credentials() effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
if not client_information: if not client_information:
if authorization_code is not None: if authorization_code is not None:
@ -425,12 +589,11 @@ def auth(
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value: if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
# Direct token request without user interaction # Direct token request without user interaction
try: try:
scope = credentials.get("scope")
tokens = client_credentials_flow( tokens = client_credentials_flow(
server_url, server_url,
server_metadata, server_metadata,
client_information, client_information,
scope, effective_scope,
) )
# Return action to save tokens and grant type # Return action to save tokens and grant type
@ -526,6 +689,7 @@ def auth(
redirect_url, redirect_url,
provider_id, provider_id,
tenant_id, tenant_id,
effective_scope,
) )
# Return action to save code verifier # Return action to save code verifier

View File

@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
mcp_service = MCPToolManageService(session=session) mcp_service = MCPToolManageService(session=session)
# Perform authentication using the service's auth method # Perform authentication using the service's auth method
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code) # Extract OAuth metadata hints from the error
mcp_service.auth_with_actions(
self.provider_entity,
self.authorization_code,
resource_metadata_url=error.resource_metadata_url,
scope_hint=error.scope_hint,
)
# Retrieve new tokens # Retrieve new tokens
self.provider_entity = mcp_service.get_provider_entity( self.provider_entity = mcp_service.get_provider_entity(

View File

@ -290,7 +290,7 @@ def sse_client(
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401: if exc.response.status_code == 401:
raise MCPAuthError() raise MCPAuthError(response=exc.response)
raise MCPConnectionError() raise MCPConnectionError()
except Exception: except Exception:
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")

View File

@ -1,3 +1,10 @@
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import httpx
class MCPError(Exception): class MCPError(Exception):
pass pass
@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):
class MCPAuthError(MCPConnectionError): class MCPAuthError(MCPConnectionError):
pass def __init__(
self,
message: str | None = None,
response: "httpx.Response | None" = None,
www_authenticate_header: str | None = None,
):
"""
MCP Authentication Error.
Args:
message: Error message
response: HTTP response object (will extract WWW-Authenticate header if provided)
www_authenticate_header: Pre-extracted WWW-Authenticate header value
"""
super().__init__(message or "Authentication failed")
# Extract OAuth metadata hints from WWW-Authenticate header
if response is not None:
www_authenticate_header = response.headers.get("WWW-Authenticate")
self.resource_metadata_url: str | None = None
self.scope_hint: str | None = None
if www_authenticate_header:
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
@staticmethod
def _extract_field(www_auth: str, field_name: str) -> str | None:
"""Extract a specific field from the WWW-Authenticate header."""
# Pattern to match field="value" or field=value
pattern = rf'{field_name}="([^"]*)"'
match = re.search(pattern, www_auth)
if match:
return match.group(1)
# Try without quotes
pattern = rf"{field_name}=([^\s,]+)"
match = re.search(pattern, www_auth)
if match:
return match.group(1)
return None
class MCPRefreshTokenError(MCPError): class MCPRefreshTokenError(MCPError):

View File

@ -149,7 +149,7 @@ class BaseSession(
messages when entered. messages when entered.
""" """
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]] _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
_request_id: int _request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT] _receive_request_type: type[ReceiveRequestT]
@ -230,7 +230,7 @@ class BaseSession(
request_id = self._request_id request_id = self._request_id
self._request_id = request_id + 1 self._request_id = request_id + 1
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue() response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
self._response_streams[request_id] = response_queue self._response_streams[request_id] = response_queue
try: try:
@ -261,11 +261,17 @@ class BaseSession(
message="No response received", message="No response received",
) )
) )
elif isinstance(response_or_error, HTTPStatusError):
# HTTPStatusError from streamable_client with preserved response object
if response_or_error.response.status_code == 401:
raise MCPAuthError(response=response_or_error.response)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
)
elif isinstance(response_or_error, JSONRPCError): elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401: if response_or_error.error.code == 401:
raise MCPAuthError( raise MCPAuthError(message=response_or_error.error.message)
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else: else:
raise MCPConnectionError( raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
@ -327,6 +333,10 @@ class BaseSession(
if isinstance(message, HTTPStatusError): if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1) response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None: if response_queue is not None:
# For 401 errors, pass the HTTPStatusError directly to preserve response object
if message.response.status_code == 401:
response_queue.put(message)
else:
response_queue.put( response_queue.put(
JSONRPCError( JSONRPCError(
jsonrpc="2.0", jsonrpc="2.0",

View File

@ -23,7 +23,7 @@ for reference.
not separate types in the schema. not separate types in the schema.
""" """
# Client support both version, not support 2025-06-18 yet. # Client support both version, not support 2025-06-18 yet.
LATEST_PROTOCOL_VERSION = "2025-03-26" LATEST_PROTOCOL_VERSION = "2025-06-18"
# Server support 2024-11-05 to allow claude to use. # Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05" SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26" DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel):
response_types_supported: list[str] response_types_supported: list[str]
grant_types_supported: list[str] | None = None grant_types_supported: list[str] | None = None
code_challenge_methods_supported: list[str] | None = None code_challenge_methods_supported: list[str] | None = None
scopes_supported: list[str] | None = None
class ProtectedResourceMetadata(BaseModel):
"""OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
resource: str | None = None
authorization_servers: list[str]
scopes_supported: list[str] | None = None
bearer_methods_supported: list[str] | None = None

View File

@ -2,7 +2,7 @@ from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum): class TracingProviderEnum(StrEnum):
@ -13,6 +13,8 @@ class TracingProviderEnum(StrEnum):
OPIK = "opik" OPIK = "opik"
WEAVE = "weave" WEAVE = "weave"
ALIYUN = "aliyun" ALIYUN = "aliyun"
MLFLOW = "mlflow"
DATABRICKS = "databricks"
TENCENT = "tencent" TENCENT = "tencent"
@ -223,5 +225,47 @@ class TencentConfig(BaseTracingConfig):
return cls.validate_project_field(v, "dify_app") return cls.validate_project_field(v, "dify_app")
class MLflowConfig(BaseTracingConfig):
"""
Model class for MLflow tracing config.
"""
tracking_uri: str = "http://localhost:5000"
experiment_id: str = "0" # Default experiment id in MLflow is 0
username: str | None = None
password: str | None = None
@field_validator("tracking_uri")
@classmethod
def tracking_uri_validator(cls, v, info: ValidationInfo):
if isinstance(v, str) and v.startswith("databricks"):
raise ValueError(
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
)
return validate_url_with_path(v, "http://localhost:5000")
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
class DatabricksConfig(BaseTracingConfig):
"""
Model class for Databricks (Databricks-managed MLflow) tracing config.
"""
experiment_id: str
host: str
client_id: str | None = None
client_secret: str | None = None
personal_access_token: str | None = None
@field_validator("experiment_id")
@classmethod
def experiment_id_validator(cls, v, info: ValidationInfo):
return validate_integer_id(v)
OPS_FILE_PATH = "ops_trace/" OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -0,0 +1,549 @@
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Any, cast
import mlflow
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
"""Convert datetime to nanosecond timestamp for MLflow API"""
if dt is None:
return None
return int(dt.timestamp() * 1_000_000_000)
class MLflowDataTrace(BaseTraceInstance):
def __init__(self, config: MLflowConfig | DatabricksConfig):
super().__init__(config)
if isinstance(config, DatabricksConfig):
self._setup_databricks(config)
else:
self._setup_mlflow(config)
# Enable async logging to minimize performance overhead
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
def _setup_databricks(self, config: DatabricksConfig):
"""Setup connection to Databricks-managed MLflow instances"""
os.environ["DATABRICKS_HOST"] = config.host
if config.client_id and config.client_secret:
# OAuth: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m?language=Environment
os.environ["DATABRICKS_CLIENT_ID"] = config.client_id
os.environ["DATABRICKS_CLIENT_SECRET"] = config.client_secret
elif config.personal_access_token:
# PAT: https://docs.databricks.com/aws/en/dev-tools/auth/pat
os.environ["DATABRICKS_TOKEN"] = config.personal_access_token
else:
raise ValueError(
"Either Databricks token (PAT) or client id and secret (OAuth) must be provided"
"See https://docs.databricks.com/aws/en/dev-tools/auth/#what-authorization-option-should-i-choose "
"for more information about the authorization options."
)
mlflow.set_tracking_uri("databricks")
mlflow.set_experiment(experiment_id=config.experiment_id)
# Remove trailing slash from host
config.host = config.host.rstrip("/")
self._project_url = f"{config.host}/ml/experiments/{config.experiment_id}/traces"
def _setup_mlflow(self, config: MLflowConfig):
"""Setup connection to MLflow instances"""
mlflow.set_tracking_uri(config.tracking_uri)
mlflow.set_experiment(experiment_id=config.experiment_id)
# Simple auth if provided
if config.username and config.password:
os.environ["MLFLOW_TRACKING_USERNAME"] = config.username
os.environ["MLFLOW_TRACKING_PASSWORD"] = config.password
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
def trace(self, trace_info: BaseTraceInfo):
"""Simple dispatch to trace methods"""
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception:
logger.exception("[MLflow] Trace error")
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
"""Create workflow span as root, with node spans as children"""
# fields with sys.xyz is added by Dify, they are duplicate to trace_info.metadata
raw_inputs = trace_info.workflow_run_inputs or {}
workflow_inputs = {k: v for k, v in raw_inputs.items() if not k.startswith("sys.")}
# Special inputs propagated by system
if trace_info.query:
workflow_inputs["query"] = trace_info.query
workflow_span = start_span_no_context(
name=TraceTaskName.WORKFLOW_TRACE.value,
span_type=SpanType.CHAIN,
inputs=workflow_inputs,
attributes=trace_info.metadata,
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
# Set reserved fields in trace-level metadata
trace_metadata = {}
if user_id := trace_info.metadata.get("user_id"):
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
if session_id := trace_info.conversation_id:
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
self._set_trace_metadata(workflow_span, trace_metadata)
try:
# Create child spans for workflow nodes
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
inputs = None
attributes = {
"node_id": node.id,
"node_type": node.node_type,
"status": node.status,
"tenant_id": node.tenant_id,
"app_id": node.app_id,
"app_name": node.title,
}
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
attributes.update(llm_attributes)
elif node.node_type == NodeType.HTTP_REQUEST:
inputs = node.process_data # contains request URL
if not inputs:
inputs = json.loads(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
name=node.title,
span_type=self._get_node_span_type(node.node_type),
parent_span=workflow_span,
inputs=inputs,
attributes=attributes,
start_time_ns=datetime_to_nanoseconds(node.created_at),
)
# Handle node errors
if node.status != "succeeded":
node_span.set_status(SpanStatusCode.ERROR)
node_span.add_event(
SpanEvent( # type: ignore[abstract]
name="exception",
attributes={
"exception.message": f"Node failed with status: {node.status}",
"exception.type": "Error",
"exception.stacktrace": f"Node failed with status: {node.status}",
},
)
)
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == NodeType.LLM:
outputs = outputs.get("text", outputs)
node_span.end(
outputs=outputs,
end_time_ns=datetime_to_nanoseconds(finished_at),
)
# Handle workflow-level errors
if trace_info.error:
workflow_span.set_status(SpanStatusCode.ERROR)
workflow_span.add_event(
SpanEvent( # type: ignore[abstract]
name="exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
finally:
workflow_span.end(
outputs=trace_info.workflow_run_outputs,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
"""Parse LLM inputs and attributes from LLM workflow node"""
if node.process_data is None:
return {}, {}
try:
data = json.loads(node.process_data)
except (json.JSONDecodeError, TypeError):
return {}, {}
inputs = self._parse_prompts(data.get("prompts"))
attributes = {
"model_name": data.get("model_name"),
"model_provider": data.get("model_provider"),
"finish_reason": data.get("finish_reason"),
}
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
attributes[SpanAttributeKey.MESSAGE_FORMAT] = "dify"
if usage := data.get("usage"):
# Set reserved token usage attributes
attributes[SpanAttributeKey.CHAT_USAGE] = {
TokenUsageKey.INPUT_TOKENS: usage.get("prompt_tokens", 0),
TokenUsageKey.OUTPUT_TOKENS: usage.get("completion_tokens", 0),
TokenUsageKey.TOTAL_TOKENS: usage.get("total_tokens", 0),
}
# Store raw usage data as well as it includes more data like price
attributes["usage"] = usage
return inputs, attributes
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
if not retrieved or not isinstance(retrieved, list):
return outputs
documents = []
for item in retrieved:
documents.append(Document(page_content=item.get("content", ""), metadata=item.get("metadata", {})))
return documents
def message_trace(self, trace_info: MessageTraceInfo):
"""Create span for CHATBOT message processing"""
if not trace_info.message_data:
return
file_list = cast(list[str], trace_info.file_list) or []
if message_file_data := trace_info.message_file_data:
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
file_list.append(f"{base_url}/{message_file_data.url}")
span = start_span_no_context(
name=TraceTaskName.MESSAGE_TRACE.value,
span_type=SpanType.LLM,
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"model_provider": trace_info.message_data.model_provider,
"model_id": trace_info.message_data.model_id,
"conversation_mode": trace_info.conversation_mode,
"file_list": file_list, # type: ignore[dict-item]
"total_price": trace_info.message_data.total_price,
**trace_info.metadata,
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "dify")
# Set token usage
span.set_attribute(
SpanAttributeKey.CHAT_USAGE,
{
TokenUsageKey.INPUT_TOKENS: trace_info.message_tokens or 0,
TokenUsageKey.OUTPUT_TOKENS: trace_info.answer_tokens or 0,
TokenUsageKey.TOTAL_TOKENS: trace_info.total_tokens or 0,
},
)
# Set reserved fields in trace-level metadata
trace_metadata = {}
if user_id := self._get_message_user_id(trace_info.metadata):
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
if session_id := trace_info.metadata.get("conversation_id"):
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
self._set_trace_metadata(span, trace_metadata)
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(
outputs=trace_info.message_data.answer,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first()
):
return end_user_data.session_id
return metadata.get("from_account_id") # type: ignore[return-value]
def tool_trace(self, trace_info: ToolTraceInfo):
span = start_span_no_context(
name=trace_info.tool_name,
span_type=SpanType.TOOL,
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
# Handle tool errors
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(
outputs=trace_info.tool_outputs,
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
span = start_span_no_context(
name=TraceTaskName.MODERATION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs or {},
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
span.end(
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
},
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
span = start_span_no_context(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
span_type=SpanType.RETRIEVER,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"metadata": trace_info.metadata, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
span.end(outputs={"documents": trace_info.documents}, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
span = start_span_no_context(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
span_type=SpanType.TOOL,
inputs=trace_info.inputs,
attributes={
"message_id": trace_info.message_id, # type: ignore[dict-item]
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
"model_id": trace_info.model_id, # type: ignore[dict-item]
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
},
start_time_ns=datetime_to_nanoseconds(start_time),
)
if trace_info.error:
span.set_status(SpanStatusCode.ERROR)
span.add_event(
SpanEvent( # type: ignore[abstract]
name="error",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
)
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
span = start_span_no_context(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
span_type=SpanType.CHAIN,
inputs=trace_info.inputs,
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
)
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.order_by(WorkflowNodeExecutionModel.created_at)
.all()
)
return workflow_nodes
def _get_node_span_type(self, node_type: str) -> str:
"""Map Dify node types to MLflow span types"""
node_type_mapping = {
NodeType.LLM: SpanType.LLM,
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
NodeType.TOOL: SpanType.TOOL,
NodeType.CODE: SpanType.TOOL,
NodeType.HTTP_REQUEST: SpanType.TOOL,
NodeType.AGENT: SpanType.AGENT,
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
def _set_trace_metadata(self, span: Span, metadata: dict):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
token = set_span_in_context(span)
update_current_trace(metadata=metadata)
finally:
if token:
detach_span_from_context(token)
def _parse_prompts(self, prompts):
"""Postprocess prompts format to be standard chat messages"""
if isinstance(prompts, str):
return prompts
elif isinstance(prompts, dict):
return self._parse_single_message(prompts)
elif isinstance(prompts, list):
messages = [self._parse_single_message(item) for item in prompts]
messages = self._resolve_tool_call_ids(messages)
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}
if (
(tool_calls := item.get("tool_calls"))
# Tool message does not contain tool calls normally
and role != "tool"
):
msg["tool_calls"] = tool_calls
if files := item.get("files"):
msg["files"] = files
return msg
def _resolve_tool_call_ids(self, messages: list[dict]):
"""
The tool call message from Dify does not contain tool call ids, which is not
ideal for debugging. This method resolves the tool call ids by matching the
tool call name and parameters with the tool instruction messages.
"""
tool_call_ids = []
for msg in messages:
if tool_calls := msg.get("tool_calls"):
tool_call_ids = [t["id"] for t in tool_calls]
if msg["role"] == "tool":
# Get the tool call id in the order of the tool call messages
# assuming Dify runs tools sequentially
if tool_call_ids:
msg["tool_call_id"] = tool_call_ids.pop(0)
return messages
def api_check(self):
"""Simple connection test"""
try:
mlflow.search_experiments(max_results=1)
return True
except Exception as e:
raise ValueError(f"MLflow connection failed: {str(e)}")
def get_project_url(self):
return self._project_url

View File

@ -120,6 +120,26 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
"other_keys": ["endpoint", "app_name"], "other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace, "trace_instance": AliyunDataTrace,
} }
case TracingProviderEnum.MLFLOW:
from core.ops.entities.config_entity import MLflowConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": MLflowConfig,
"secret_keys": ["password"],
"other_keys": ["tracking_uri", "experiment_id", "username"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.DATABRICKS:
from core.ops.entities.config_entity import DatabricksConfig
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
return {
"config_class": DatabricksConfig,
"secret_keys": ["personal_access_token", "client_secret"],
"other_keys": ["host", "client_id", "experiment_id"],
"trace_instance": MLflowDataTrace,
}
case TracingProviderEnum.TENCENT: case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig from core.ops.entities.config_entity import TencentConfig
@ -274,6 +294,8 @@ class OpsTraceManager:
raise ValueError("App not found") raise ValueError("App not found")
tenant_id = app.tenant_id tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = cls.decrypt_tracing_config( decrypt_tracing_config = cls.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config tenant_id, tracing_provider, trace_config_data.tracing_config
) )

View File

@ -147,3 +147,14 @@ def validate_project_name(project: str, default_name: str) -> str:
return default_name return default_name
return project.strip() return project.strip()
def validate_integer_id(id_str: str) -> str:
"""
Validate and normalize integer ID
"""
id_str = id_str.strip()
if not id_str.isdigit():
raise ValueError("ID must be a valid integer")
return id_str

View File

@ -1,12 +1,20 @@
import logging import logging
import os import os
import uuid import uuid
from datetime import datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any, cast from typing import Any, cast
import wandb import wandb
import weave import weave
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from weave.trace_server.trace_server_interface import (
CallEndReq,
CallStartReq,
EndedCallSchemaForInsert,
StartedCallSchemaForInsert,
SummaryInsertMap,
TraceStatus,
)
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.config_entity import WeaveConfig
@ -57,6 +65,7 @@ class WeaveDataTrace(BaseTraceInstance):
) )
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls: dict[str, Any] = {} self.calls: dict[str, Any] = {}
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
def get_project_url( def get_project_url(
self, self,
@ -424,6 +433,13 @@ class WeaveDataTrace(BaseTraceInstance):
logger.debug("Weave API check failed: %s", str(e)) logger.debug("Weave API check failed: %s", str(e))
raise ValueError(f"Weave API check failed: {str(e)}") raise ValueError(f"Weave API check failed: {str(e)}")
def _normalize_time(self, dt: datetime | None) -> datetime:
if dt is None:
return datetime.now(UTC)
if dt.tzinfo is None:
return dt.replace(tzinfo=UTC)
return dt
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None): def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
inputs = run_data.inputs inputs = run_data.inputs
if inputs is None: if inputs is None:
@ -437,19 +453,71 @@ class WeaveDataTrace(BaseTraceInstance):
elif not isinstance(attributes, dict): elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)} attributes = {"attributes": str(attributes)}
call = self.weave_client.create_call( start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
op=run_data.op, started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
inputs=inputs, trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
if trace_id is None:
trace_id = run_data.id
call_start_req = CallStartReq(
start=StartedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
op_name=str(run_data.op),
trace_id=trace_id,
parent_id=parent_run_id,
started_at=started_at,
attributes=attributes, attributes=attributes,
inputs=inputs,
wb_user_id=None,
) )
self.calls[run_data.id] = call )
if parent_run_id: self.weave_client.server.call_start(call_start_req)
self.calls[run_data.id].parent_id = parent_run_id self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
def finish_call(self, run_data: WeaveTraceModel): def finish_call(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id) call_meta = self.calls.get(run_data.id)
if call: if not call_meta:
exception = Exception(run_data.exception) if run_data.exception else None
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
else:
raise ValueError(f"Call with id {run_data.id} not found") raise ValueError(f"Call with id {run_data.id} not found")
attributes = run_data.attributes
if attributes is None:
attributes = {}
elif not isinstance(attributes, dict):
attributes = {"attributes": str(attributes)}
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
if elapsed_ms < 0:
elapsed_ms = 0
status_counts = {
TraceStatus.SUCCESS: 0,
TraceStatus.ERROR: 0,
}
if run_data.exception:
status_counts[TraceStatus.ERROR] = 1
else:
status_counts[TraceStatus.SUCCESS] = 1
summary: dict[str, Any] = {
"status_counts": status_counts,
"weave": {"latency_ms": elapsed_ms},
}
exception_str = str(run_data.exception) if run_data.exception else None
call_end_req = CallEndReq(
end=EndedCallSchemaForInsert(
project_id=self.project_id,
id=run_data.id,
ended_at=ended_at,
exception=exception_str,
output=run_data.outputs,
summary=cast(SummaryInsertMap, summary),
)
)
self.weave_client.server.call_end(call_end_req)

View File

@ -309,11 +309,12 @@ class ProviderManager:
(model for model in available_models if model.model == "gpt-4"), available_models[0] (model for model in available_models if model.model == "gpt-4"), available_models[0]
) )
default_model = TenantDefaultModel() default_model = TenantDefaultModel(
default_model.tenant_id = tenant_id tenant_id=tenant_id,
default_model.model_type = model_type.to_origin_model_type() model_type=model_type.to_origin_model_type(),
default_model.provider_name = available_model.provider.provider provider_name=available_model.provider.provider,
default_model.model_name = available_model.model model_name=available_model.model,
)
db.session.add(default_model) db.session.add(default_model)
db.session.commit() db.session.commit()

View File

@ -22,6 +22,18 @@ logger = logging.getLogger(__name__)
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T", bound="MatrixoneVector")
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneConfig(BaseModel): class MatrixoneConfig(BaseModel):
host: str = "localhost" host: str = "localhost"
@ -206,19 +218,6 @@ class MatrixoneVector(BaseVector):
self.client.delete() self.client.delete()
T = TypeVar("T", bound=MatrixoneVector)
def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVectorFactory(AbstractVectorFactory): class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict: if dataset.index_struct_dict:

View File

@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast from typing import Any, Union, cast
from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, select, text from sqlalchemy import and_, or_, select
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import ( from core.app.app_config.entities import (
DatasetEntity, DatasetEntity,
@ -1023,60 +1022,55 @@ class DatasetRetrieval:
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
): ):
if value is None and condition not in ("empty", "not empty"): if value is None and condition not in ("empty", "not empty"):
return return filters
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition: match condition:
case "contains": case "contains":
filters.append( filters.append(json_field.like(f"%{value}%"))
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
case "not contains": case "not contains":
filters.append( filters.append(json_field.notlike(f"%{value}%"))
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}%"}
)
)
case "start with": case "start with":
filters.append( filters.append(json_field.like(f"{value}%"))
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"{value}%"}
)
)
case "end with": case "end with":
filters.append( filters.append(json_field.like(f"%{value}"))
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
**{key: metadata_name, key_value: f"%{value}"}
)
)
case "is" | "=": case "is" | "=":
if isinstance(value, str): if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"') filters.append(json_field == value)
else: elif isinstance(value, (int, float)):
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
case "is not" | "": case "is not" | "":
if isinstance(value, str): if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"') filters.append(json_field != value)
else: elif isinstance(value, (int, float)):
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
case "empty": case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None)) filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty": case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None)) filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<": case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
case "after" | ">": case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
case "" | "<=": case "" | "<=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
case "" | ">=": case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value) filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
case _: case _:
pass pass
return filters return filters
def _fetch_model_config( def _fetch_model_config(

Some files were not shown because too many files have changed in this diff Show More