Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice 2025-10-20 10:29:09 +08:00
commit c22ba3f537
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
542 changed files with 11548 additions and 7438 deletions

View File

@ -39,25 +39,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: uv sync --project api --dev run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check - name: Run pyrefly check
run: | run: |
cd api cd api
uv add --dev pyrefly uv add --dev pyrefly
uv run pyrefly check || true uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers - name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -14,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -14,7 +14,7 @@ The codebase is split into:
- Run backend CLI commands through `uv run --project api <command>`. - Run backend CLI commands through `uv run --project api <command>`.
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review. - Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. - Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.

View File

@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
## Advanced Setup ## Advanced Setup
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)

View File

@ -55,3 +55,12 @@ else:
"properties", "properties",
} }
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -15,6 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization") auth_token = extract_access_token(request)
if auth_header is None: if not auth_token:
raise Unauthorized("Authorization header is missing.") raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY: if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("desc", type=str, location="json") .add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("copyright", type=str, location="json") .add_argument("desc", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json") .add_argument("copyright", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json") .add_argument("privacy_policy", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") .add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json") .add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()

View File

@ -1,5 +1,4 @@
import flask_restx import flask_restx
from flask import Response
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus from flask_restx._http import HTTPStatus
from sqlalchemy import select from sqlalchemy import select
@ -13,7 +12,7 @@ from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import api, console_ns from . import api, console_ns
from .wraps import account_initialization_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = { api_key_fields = {
"id": fields.String, "id": fields.String,
@ -68,14 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys} return {"items": keys}
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model) _get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.has_edit_permission:
raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@ -156,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -177,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -206,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"
@ -227,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp: Response):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"

View File

@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("app_mode", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("model_mode", type=str, required=True, location="args") .add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args") .add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("model_name", type=str, required=True, location="args") .add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -27,9 +27,11 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=uuid_value, required=True, location="args") reqparse.RequestParser()
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") .add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -2,13 +2,13 @@ from typing import Literal
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -16,7 +16,7 @@ from fields.annotation_fields import (
annotation_fields, annotation_fields,
annotation_hit_history_fields, annotation_hit_history_fields,
) )
from libs.login import current_account_with_tenant, login_required from libs.login import login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@ -41,17 +41,15 @@ class AnnotationReplyActionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("score_threshold", required=True, type=float, location="json") reqparse.RequestParser()
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") .add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json") .add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
@ -70,12 +68,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200 return result, 200
@ -101,17 +95,12 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id): def post(self, app_id, annotation_setting_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args() args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@ -129,12 +118,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
@ -166,12 +151,8 @@ class AnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str) keyword = request.args.get("keyword", default="", type=str)
@ -207,16 +188,14 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id): def post(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("question", required=True, type=str, location="json") reqparse.RequestParser()
parser.add_argument("answer", required=True, type=str, location="json") .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation return annotation
@ -224,12 +203,8 @@ class AnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id): def delete(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly # Use request.args.getlist to get annotation_ids array directly
@ -262,12 +237,8 @@ class AnnotationExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)} response = {"data": marshal(annotation_list, annotation_fields)}
@ -286,18 +257,16 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("question", required=True, type=str, location="json") reqparse.RequestParser()
parser.add_argument("answer", required=True, type=str, location="json") .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation return annotation
@ -305,12 +274,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id): def delete(self, app_id, annotation_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@ -329,12 +294,8 @@ class AnnotationBatchImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id): def post(self, app_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# check file # check file
if "file" not in request.files: if "file" not in request.files:
@ -362,12 +323,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id): def get(self, app_id, job_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
@ -399,12 +356,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id): def get(self, app_id, annotation_id):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id) app_id = str(app_id)

View File

@ -1,7 +1,5 @@
import uuid import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required, enterprise_license_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
from models import Account, App from models import App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -56,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
"""Get app list""" """Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value): def uuid_list(value):
try: try:
@ -63,34 +63,36 @@ class AppListApi(Resource):
except ValueError: except ValueError:
abort(400, message="Invalid UUID format in tag_ids.") abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument( .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
"mode", .add_argument(
type=str, "mode",
choices=[ type=str,
"completion", choices=[
"chat", "completion",
"advanced-chat", "chat",
"workflow", "advanced-chat",
"agent-chat", "workflow",
"channel", "agent-chat",
"all", "channel",
], "all",
default="all", ],
location="args", default="all",
required=False, location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
) )
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args() args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -129,30 +131,26 @@ class AppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, location="json") parser = (
parser.add_argument("description", type=validate_description_length, location="json") reqparse.RequestParser()
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") .add_argument("name", type=str, required=True, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None: if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
if not isinstance(current_user, Account): app = app_service.create_app(current_tenant_id, args, current_user)
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201 return app, 201
@ -205,21 +203,20 @@ class AppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def put(self, app_model): def put(self, app_model):
"""Update app""" """Update app"""
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
parser = reqparse.RequestParser() .add_argument("icon_type", type=str, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("description", type=validate_description_length, location="json") .add_argument("icon_background", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("max_active_requests", type=int, location="json")
parser.add_argument("icon_background", type=str, location="json") )
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -248,12 +245,9 @@ class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model): def delete(self, app_model):
"""Delete app""" """Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService() app_service = AppService()
app_service.delete_app(app_model) app_service.delete_app(app_model)
@ -283,27 +277,28 @@ class AppCopyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def post(self, app_model): def post(self, app_model):
"""Copy app""" """Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, location="json") reqparse.RequestParser()
parser.add_argument("description", type=validate_description_length, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app( result = import_service.import_app(
account=account, account=current_user,
import_mode=ImportMode.YAML_CONTENT, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.get("name"), name=args.get("name"),
@ -340,16 +335,15 @@ class AppExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = (
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") reqparse.RequestParser()
parser.add_argument("workflow_id", type=str, location="args") .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return { return {
@ -371,13 +365,9 @@ class AppNameApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -408,14 +398,13 @@ class AppIconApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
parser = reqparse.RequestParser() )
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -441,13 +430,9 @@ class AppSiteStatus(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -475,11 +460,11 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -520,13 +505,14 @@ class AppTraceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id): def post(self, app_id):
# add app trace # add app trace
if not current_user.is_editor: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("enabled", type=bool, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(

View File

@ -1,11 +1,11 @@
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -26,22 +26,22 @@ class AppImportApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
parser = reqparse.RequestParser() .add_argument("yaml_content", type=str, location="json")
parser.add_argument("mode", type=str, required=True, location="json") .add_argument("yaml_url", type=str, location="json")
parser.add_argument("yaml_content", type=str, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json") .add_argument("description", type=str, location="json")
parser.add_argument("name", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("description", type=str, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("icon_background", type=str, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("app_id", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") )
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
@ -80,11 +80,10 @@ class AppImportConfirmApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id): def post(self, import_id):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -107,11 +106,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model @get_app_model
@account_initialization_required @account_initialization_required
@marshal_with(app_import_check_dependencies_fields) @marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model) result = import_service.check_dependencies(app_model=app_model)

View File

@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, location="json") reqparse.RequestParser()
parser.add_argument("text", type=str, location="json") .add_argument("message_id", type=str, location="json")
parser.add_argument("voice", type=str, location="json") .add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
@ -166,8 +168,7 @@ class TextModesApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(

View File

@ -2,7 +2,7 @@ import logging
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("model_config", type=dict, required=True, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=True, location="json")
raise Forbidden() .add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
parser = reqparse.RequestParser() .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("inputs", type=dict, required=True, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("query", type=str, required=True, location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("model_config", type=dict, required=True, location="json") )
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"

View File

@ -1,17 +1,16 @@
from datetime import datetime from datetime import datetime
import pytz # pip install pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import ( from fields.conversation_fields import (
@ -22,8 +21,8 @@ from fields.conversation_fields import (
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -57,18 +56,24 @@ class CompletionConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = sa.select(Conversation).where( query = sa.select(Conversation).where(
@ -84,6 +89,7 @@ class CompletionConversationApi(Resource):
) )
account = current_user account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -137,9 +143,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields) @marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -154,14 +159,12 @@ class CompletionConversationDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -206,26 +209,32 @@ class ChatConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
) type=str,
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") choices=["annotated", "not_annotated", "all"],
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") default="all",
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") location="args",
parser.add_argument( )
"sort_by", .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
type=str, .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
choices=["created_at", "-created_at", "updated_at", "-updated_at"], .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
required=False, .add_argument(
default="-updated_at", "sort_by",
location="args", type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -260,6 +269,7 @@ class ChatConversationApi(Resource):
) )
account = current_user account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -341,9 +351,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields) @marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -358,14 +367,12 @@ class ChatConversationDetailApi(Resource):
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -374,6 +381,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT) @get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields) @marshal_with(paginated_conversation_variable_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
stmt = ( stmt = (

View File

@ -42,10 +42,12 @@ class RuleGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -92,11 +94,13 @@ class RuleCodeGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") .add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -139,9 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -188,14 +194,16 @@ class InstructionGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("flow_id", type=str, required=True, default="", location="json") reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=False, default="", location="json") .add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json") .add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json") .add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") .add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
code_template = ( code_template = (
@ -293,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()
match args["type"]: match args["type"]:
case "prompt": case "prompt":

View File

@ -1,16 +1,15 @@
import json import json
from enum import StrEnum from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_server_fields from fields.app_fields import app_server_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application") @api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model @get_app_model
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
) )
@api.response(201, "MCP server configuration created successfully", app_server_fields) @api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
description = args.get("description") description = args.get("description")
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False), parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE, status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id, app_id=app_model.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16), server_code=AppMCPServer.generate_server_code(16),
) )
db.session.add(server) db.session.add(server)
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields) @api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@api.response(404, "Server not found") @api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model): def put(self, app_model):
if not current_user.is_editor: parser = (
raise NotFound() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("id", type=str, required=True, location="json")
parser.add_argument("id", type=str, required=True, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("status", type=str, required=False, location="json")
parser.add_argument("status", type=str, required=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id): def get(self, server_id):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id) .where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id) .where(AppMCPServer.tenant_id == current_tenant_id)
.first() .first()
) )
if not server: if not server:

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -56,19 +56,19 @@ class ChatMessageListApi(Resource):
) )
@api.response(200, "Success", message_infinite_scroll_pagination_fields) @api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser = reqparse.RequestParser() .add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("first_id", type=uuid_value, location="args") )
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
conversation = ( conversation = (
@ -154,12 +154,13 @@ class MessageFeedbackApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_model): def post(self, app_model):
if current_user is None: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", required=True, type=uuid_value, location="json") reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") .add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = str(args["message_id"]) message_id = str(args["message_id"])
@ -211,23 +212,21 @@ class MessageAnnotationApi(Resource):
) )
@api.response(200, "Annotation created successfully", annotation_fields) @api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@get_app_model @account_initialization_required
@marshal_with(annotation_fields) @edit_permission_required
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("message_id", required=False, type=uuid_value, location="json")
raise Forbidden() .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
parser = reqparse.RequestParser() .add_argument("annotation_reply", required=False, type=dict, location="json")
parser.add_argument("message_id", required=False, type=uuid_value, location="json") )
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@ -270,6 +269,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id): def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id) message_id = str(message_id)
try: try:
@ -304,12 +304,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields) @api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found") @api.response(404, "Message not found")
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(message_detail_fields) @marshal_with(message_detail_fields)
def get(self, app_model, message_id): def get(self, app_model, message_id: str):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -9,29 +9,35 @@ from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Site from models import Site
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = (
parser.add_argument("title", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("icon_type", type=str, required=False, location="json") .add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json") .add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json") .add_argument("icon", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json") .add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") .add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json") .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json") .add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json") .add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json") .add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument( .add_argument("custom_disclaimer", type=str, required=False, location="json")
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" .add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
) )
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args() return parser.parse_args()
@ -107,8 +113,6 @@ class AppSite(Resource):
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -142,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask import jsonify from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message from models import AppMode, Message
@ -37,11 +36,13 @@ class DailyMessageStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -53,6 +54,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -109,13 +111,15 @@ class DailyConversationStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -175,11 +179,13 @@ class DailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -191,7 +197,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -247,11 +253,13 @@ class DailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -264,7 +272,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -322,11 +330,13 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -346,7 +356,7 @@ FROM
c.app_id = :app_id c.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -413,11 +423,13 @@ class UserSatisfactionRateStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -433,7 +445,7 @@ WHERE
m.app_id = :app_id m.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -494,11 +506,13 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -510,7 +524,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -566,11 +580,13 @@ class TokensPerSecondStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -585,7 +601,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc

View File

@ -12,7 +12,7 @@ import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -27,9 +27,8 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from models.account import Account
from models.model import AppMode from models.model import AppMode
from models.workflow import Workflow from models.workflow import Workflow
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get draft workflow Get draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model) workflow = workflow_service.get_draft_workflow(app_model=app_model)
@ -110,24 +105,24 @@ class DraftWorkflowApi(Resource):
@api.response(200, "Draft workflow synced successfully", workflow_fields) @api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration") @api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied") @api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, nullable=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -149,10 +144,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
else: else:
abort(415) abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -206,24 +197,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
parser = reqparse.RequestParser() .add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("inputs", type=dict, location="json") .add_argument("files", type=list, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("files", type=list, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") )
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -271,18 +259,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -323,18 +306,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -375,19 +353,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -428,19 +400,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -480,20 +446,17 @@ class DraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("files", type=list, required=False, location="json")
raise Forbidden() )
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
@ -526,17 +489,11 @@ class WorkflowTaskStopApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str): def post(self, app_model: App, task_id: str):
""" """
Stop workflow task Stop workflow task
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility # Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check) # Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id) AppQueueManager.set_stop_flag_no_user_check(task_id)
@ -568,21 +525,18 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow node Run draft workflow node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=False, location="json", default="")
raise Forbidden() .add_argument("files", type=list, location="json", default=[])
)
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args() args = parser.parse_args()
user_inputs = args.get("inputs") user_inputs = args.get("inputs")
@ -622,17 +576,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflow Get published workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model # fetch published workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model) workflow = workflow_service.get_published_workflow(app_model=app_model)
@ -644,19 +592,17 @@ class PublishedWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Publish workflow Publish workflow
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
# The role of the current user in the ta table must be admin, owner, or editor reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("marked_name", type=str, required=False, default="", location="json")
raise Forbidden() .add_argument("marked_comment", type=str, required=False, default="", location="json")
)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -702,17 +648,11 @@ class DefaultBlockConfigsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs # Get default block configs
workflow_service = WorkflowService() workflow_service = WorkflowService()
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
@ -729,18 +669,12 @@ class DefaultBlockConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str): def get(self, app_model: App, block_type: str):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -769,24 +703,23 @@ class ConvertToWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Convert basic mode of chatbot app to workflow mode Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App Convert Completion App to Workflow App
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") .add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json") .add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") .add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
else: else:
args = {} args = {}
@ -812,21 +745,20 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields) @marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflows Get published workflows
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
raise Forbidden() .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
parser = reqparse.RequestParser() .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args() args = parser.parse_args()
page = int(args.get("page", 1)) page = int(args.get("page", 1))
limit = int(args.get("limit", 10)) limit = int(args.get("limit", 10))
@ -879,19 +811,17 @@ class WorkflowByIdApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str): def patch(self, app_model: App, workflow_id: str):
""" """
Update workflow attributes Update workflow attributes
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
# Check permission reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("marked_name", type=str, required=False, location="json")
raise Forbidden() .add_argument("marked_comment", type=str, required=False, location="json")
)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -934,16 +864,11 @@ class WorkflowByIdApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str): def delete(self, app_model: App, workflow_id: str):
""" """
Delete workflow Delete workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction

View File

@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource):
""" """
Get workflow app logs Get workflow app logs
""" """
parser = reqparse.RequestParser() parser = (
parser.add_argument("keyword", type=str, location="args") reqparse.RequestParser()
parser.add_argument( .add_argument("keyword", type=str, location="args")
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" .add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None args.status = WorkflowExecutionStatus(args.status) if args.status else None

View File

@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App, AppMode from models import Account, App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -320,10 +321,11 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,5 @@
from typing import cast from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
@ -9,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import ( from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields, advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields, workflow_run_detail_fields,
workflow_run_node_execution_list_fields, workflow_run_node_execution_list_fields,
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
from libs.custom_inputs import time_duration
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource): class AdvancedChatAppWorkflowRunListApi(Resource):
@ -25,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc(description="Get advanced chat workflow run list") @api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -35,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING if not specified
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result
@ -52,6 +170,8 @@ class WorkflowRunListApi(Resource):
@api.doc(description="Get workflow run list") @api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -62,13 +182,64 @@ class WorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING for workflow if not specified (backward compatibility)
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask import jsonify from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
@ -29,11 +28,13 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -49,7 +50,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -97,11 +98,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -117,7 +120,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -165,11 +168,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -185,7 +190,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -238,11 +243,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -271,7 +278,7 @@ GROUP BY
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc

View File

@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models import App, AppMode from models import App, AppMode
from models.account import Account
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None: def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]): def decorator(view_func: Callable[P1, R1]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser() active_check_parser = (
active_check_parser.add_argument( reqparse.RequestParser()
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
) .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
active_check_parser.add_argument( .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
) )
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = reqparse.RequestParser() active_parser = (
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") .add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument( .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
"interface_language", type=supported_language, required=True, nullable=False, location="json" .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
) )
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate") @console_ns.route("/activate")

View File

@ -45,10 +45,12 @@ class ApiKeyAuthDataSourceBinding(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("category", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args) ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:

View File

@ -2,13 +2,12 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import api, console_ns
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import Account from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -1,7 +1,5 @@
from typing import cast
import flask_login import flask_login
from flask import request from flask import make_response, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
import services import services
@ -26,7 +24,17 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from models.account import Account from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
@ -42,11 +50,13 @@ class LoginApi(Resource):
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("password", type=str, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") .add_argument("password", type=str, required=True, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json") .add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@ -89,19 +99,36 @@ class LoginApi(Resource):
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout") @console_ns.route("/logout")
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def post(self):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin): if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"} response = make_response({"result": "success"})
AccountService.logout(account=account) else:
flask_login.logout_user() AccountService.logout(account=account)
return {"result": "success"} flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password") @console_ns.route("/reset-password")
@ -109,9 +136,11 @@ class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans": if args["language"] is not None and args["language"] == "zh-Hans":
@ -137,9 +166,11 @@ class ResetPasswordSendEmailApi(Resource):
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -170,10 +201,12 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -220,18 +253,46 @@ class EmailCodeLoginApi(Resource):
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token") @console_ns.route("/refresh-token")
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() # Get refresh token from cookie instead of request body
parser.add_argument("refresh_token", type=str, required=True, location="json") refresh_token = request.cookies.get("refresh_token")
args = parser.parse_args()
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try: try:
new_token_pair = AccountService.refresh_token(args["refresh_token"]) new_token_pair = AccountService.refresh_token(refresh_token)
return {"result": "success", "data": new_token_pair.model_dump()}
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e: except Exception as e:
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@ -14,8 +14,12 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from libs.token import (
from models.account import AccountStatus set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -153,9 +157,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
) )
return redirect( response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
) set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

View File

@ -1,16 +1,15 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request from flask import jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account from models import Account
from models.model import OAuthProviderApp from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -24,8 +23,7 @@ T = TypeVar("T")
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id") client_id = parsed_args.get("client_id")
if not client_id: if not client_id:
@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri") redirect_uri = parsed_args.get("redirect_uri")
@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required @account_initialization_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = (
parser.add_argument("grant_type", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=False, location="json") .add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json") .add_argument("code", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json") .add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json") .add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
try: try:

View File

@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@ -14,17 +13,15 @@ class Subscription(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser = (
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) reqparse.RequestParser()
args = parser.parse_args() .add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
assert isinstance(current_user, Account) .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
) )
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices") @console_ns.route("/billing/invoices")
@ -34,7 +31,6 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None return BillingService.get_invoices(current_user.email, current_tenant_id)
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -2,8 +2,7 @@ from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from .. import console_ns from .. import console_ns
@ -17,19 +16,16 @@ class ComplianceApi(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device") device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link( return BillingService.get_compliance_download_link(
doc_name=args.doc_name, doc_name=args.doc_name,
account_id=current_user.id, account_id=current_user.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
ip=ip_address, ip=ip_address,
device_info=device_info, device_info=device_info,
) )

View File

@ -246,12 +246,12 @@ class DataSourceNotionApi(Resource):
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument( .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args

View File

@ -1,7 +1,6 @@
from typing import Any, cast from typing import Any, cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -138,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids") ids = request.args.getlist("ids")
@ -146,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true" include_all = request.args.get("include_all", default="false").lower() == "true"
if ids: if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else: else:
datasets, total = DatasetService.get_datasets( datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all page, limit, current_tenant_id, current_user, search, tag_ids, include_all
) )
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -207,50 +206,53 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=validate_description_length, "description",
nullable=True, type=validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"external_knowledge_api_id", .add_argument(
type=str, "external_knowledge_api_id",
nullable=True, type=str,
required=False, nullable=True,
) required=False,
parser.add_argument( )
"provider", .add_argument(
type=str, "provider",
nullable=True, type=str,
choices=Dataset.PROVIDER_LIST, nullable=True,
required=False, choices=Dataset.PROVIDER_LIST,
default="vendor", required=False,
) default="vendor",
parser.add_argument( )
"external_knowledge_id", .add_argument(
type=str, "external_knowledge_id",
nullable=True, type=str,
required=False, nullable=True,
required=False,
)
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -258,11 +260,11 @@ class DatasetListApi(Resource):
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=args["name"],
description=args["description"], description=args["description"],
indexing_technique=args["indexing_technique"], indexing_technique=args["indexing_technique"],
account=cast(Account, current_user), account=current_user,
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
@ -286,6 +288,7 @@ class DatasetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -305,7 +308,7 @@ class DatasetApi(Resource):
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -351,73 +354,76 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
help="type is required. Name must be between 1 to 40 characters.", nullable=False,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length) )
parser.add_argument( .add_argument("description", location="json", store_missing=False, type=validate_description_length)
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"permission", .add_argument(
type=str, "permission",
location="json", type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), location="json",
help="Invalid permission.", choices=(
) DatasetPermissionEnum.ONLY_ME,
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") DatasetPermissionEnum.ALL_TEAM,
parser.add_argument( DatasetPermissionEnum.PARTIAL_TEAM,
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." ),
) help="Invalid permission.",
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") )
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
parser.add_argument( "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
"external_retrieval_model", )
type=dict, .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
required=False, .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
nullable=True, .add_argument(
location="json", "external_retrieval_model",
help="Invalid external retrieval model.", type=dict,
) required=False,
nullable=True,
parser.add_argument( location="json",
"external_knowledge_id", help="Invalid external retrieval model.",
type=str, )
required=False, .add_argument(
nullable=True, "external_knowledge_id",
location="json", type=str,
help="Invalid external knowledge id.", required=False,
) nullable=True,
location="json",
parser.add_argument( help="Invalid external knowledge id.",
"external_knowledge_api_id", )
type=str, .add_argument(
required=False, "external_knowledge_api_id",
nullable=True, type=str,
location="json", required=False,
help="Invalid external knowledge api id.", nullable=True,
) location="json",
help="Invalid external knowledge api id.",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
required=False, type=dict,
nullable=True, required=False,
location="json", nullable=True,
help="Invalid icon info.", location="json",
help="Invalid icon info.",
)
) )
args = parser.parse_args() args = parser.parse_args()
data = request.get_json() data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting # check embedding model setting
if ( if (
@ -440,7 +446,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
@ -464,9 +470,9 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id): def delete(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
try: try:
@ -505,6 +511,7 @@ class DatasetQueryApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -539,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument( .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
"indexing_technique", .add_argument(
type=str, "indexing_technique",
required=True, type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
location="json", nullable=True,
) location="json",
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") )
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument( .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
extract_settings = [] extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file": if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars( file_details = db.session.scalars(
select(UploadFile).where( select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
).all() ).all()
if file_details is None: if file_details is None:
@ -592,7 +598,7 @@ class DatasetIndexingEstimateApi(Resource):
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
"notion_page_type": page["type"], "notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=args["doc_form"], document_model=args["doc_form"],
@ -608,7 +614,7 @@ class DatasetIndexingEstimateApi(Resource):
"provider": website_info_list["provider"], "provider": website_info_list["provider"],
"job_id": website_info_list["job_id"], "job_id": website_info_list["job_id"],
"url": url, "url": url,
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
"mode": "crawl", "mode": "crawl",
"only_main_content": website_info_list["only_main_content"], "only_main_content": website_info_list["only_main_content"],
} }
@ -621,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],
@ -652,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(related_app_list) @marshal_with(related_app_list)
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -683,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = db.session.scalars( documents = db.session.scalars(
select(Document).where( select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
).all() ).all()
documents_status = [] documents_status = []
for document in documents: for document in documents:
@ -739,10 +745,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
).all() ).all()
return {"items": keys} return {"items": keys}
@ -752,12 +757,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count() .count()
) )
@ -770,7 +776,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -790,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, api_key_id): def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
@ -799,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .where(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
) )
@ -898,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

View File

@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -53,9 +52,8 @@ from fields.document_fields import (
document_with_segments_fields, document_with_segments_fields,
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -79,12 +78,13 @@ class DocumentResource(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
return document return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args req_data = request.args
document_id = req_data.get("document_id") document_id = req_data.get("document_id")
@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -289,20 +292,20 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" .add_argument(
) "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
parser.add_argument("data_source", type=dict, required=False, location="json") )
parser.add_argument("process_rule", type=dict, required=False, location="json") .add_argument("data_source", type=dict, required=False, location="json")
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") .add_argument("process_rule", type=dict, required=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") .add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args) knowledge_config = KnowledgeConfig.model_validate(args)
@ -372,27 +375,28 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", .add_argument(
type=str, "indexing_technique",
choices=Dataset.INDEXING_TECHNIQUE_LIST, type=str,
required=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=False, required=True,
location="json", nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
) )
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args) knowledge_config = KnowledgeConfig.model_validate(args)
@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=args["embedding_model_provider"], provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"], model=args["embedding_model"],
@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
knowledge_config=knowledge_config, knowledge_config=knowledge_config,
account=cast(Account, current_user), account=current_user,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
try: try:
estimate_response = indexing_runner.indexing_estimate( estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
[extract_setting], [extract_setting],
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, batch): def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
batch = str(batch) batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"], "notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=document.doc_form, document_model=document.doc_form,
@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"provider": data_source_info["provider"], "provider": data_source_info["provider"],
"job_id": data_source_info["job_id"], "job_id": data_source_info["job_id"],
"url": data_source_info["url"], "url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
"mode": data_source_info["mode"], "mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"], "only_main_content": data_source_info["only_main_content"],
} }
@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, dataset_id, document_id): def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
@ -1034,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") "document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -1077,14 +1087,14 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields) @marshal_with(document_fields)
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset) DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
"""sync website document.""" """sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
if document.data_source_type != "website_crawl": if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.") raise ValueError("Document is not a website document.")

View File

@ -60,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("status", type=str, action="append", default=[], location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args") .add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("enabled", type=str, default="all", location="args") .add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -244,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset) segment = SegmentService.create_segment(args, document, dataset)
@ -309,12 +313,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("keywords", type=list, required=False, nullable=True, location="json")
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" .add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
) )
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
@ -385,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json") "upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
upload_file_id = args["upload_file_id"] upload_file_id = args["upload_file_id"]
@ -484,8 +491,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
content = args["content"] content = args["content"]
@ -521,10 +529,12 @@ class ChildChunkAddApi(Resource):
) )
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -578,8 +588,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") "chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
chunks_data = args["chunks"] chunks_data = args["chunks"]
@ -700,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
content = args["content"] content = args["content"]

View File

@ -1,7 +1,4 @@
from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -10,8 +7,7 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search page, limit, current_tenant_id, search
) )
response = { response = {
"data": [item.to_dict() for item in external_knowledge_apis], "data": [item.to_dict() for item in external_knowledge_apis],
@ -60,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument( parser = (
"name", reqparse.RequestParser()
nullable=False, .add_argument(
required=True, "name",
help="Name is required. Name must be between 1 to 100 characters.", nullable=False,
type=_validate_name, required=True,
) help="Name is required. Name must be between 1 to 100 characters.",
parser.add_argument( type=_validate_name,
"settings", )
type=dict, .add_argument(
location="json", "settings",
nullable=False, type=dict,
required=True, location="json",
nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -85,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
try: try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args tenant_id=current_tenant_id, user_id=current_user.id, args=args
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -115,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, external_knowledge_api_id): def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 100 characters.",
) type=_validate_name,
parser.add_argument( )
"settings", .add_argument(
type=dict, "settings",
location="json", type=dict,
nullable=False, location="json",
required=True, nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"]) ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
args=args, args=args,
@ -148,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, external_knowledge_api_id): def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -199,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument( .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="name is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
) )
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -223,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
try: try:
dataset = ExternalDatasetService.create_external_dataset( dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
args=args, args=args,
) )
@ -255,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -265,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument("query", type=str, location="json") reqparse.RequestParser()
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@ -277,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=cast(Account, current_user), account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=args["metadata_filtering_conditions"],
) )
@ -304,15 +313,17 @@ class BedrockRetrievalApi(Resource):
) )
@api.response(200, "Bedrock retrieval test completed") @api.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
"query", .add_argument(
nullable=False, "query",
required=True, nullable=False,
type=str, required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
) )
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args() args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service

View File

@ -48,11 +48,12 @@ class DatasetsHitTestingBase:
@staticmethod @staticmethod
def parse_args(): def parse_args():
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("query", type=str, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args() return parser.parse_args()
@staticmethod @staticmethod

View File

@ -1,13 +1,12 @@
from typing import Literal from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs, MetadataArgs,
@ -24,9 +23,12 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def post(self, dataset_id): def post(self, dataset_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser = (
parser.add_argument("name", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args) metadata_args = MetadataArgs.model_validate(args)
@ -59,8 +61,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
name = args["name"] name = args["name"]
@ -79,6 +81,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def delete(self, dataset_id, metadata_id): def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -108,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]): def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -128,14 +132,16 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") "operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args) metadata_args = MetadataOperationData.model_validate(args)

View File

@ -4,10 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
@ -23,12 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, provider_id: str): def get(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id tenant_id = current_tenant_id
if not current_user.has_edit_permission:
raise Forbidden()
credential_id = request.args.get("credential_id") credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
@ -130,17 +126,17 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument(
parser = reqparse.RequestParser() "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
parser.add_argument( )
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -177,16 +173,17 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.has_edit_permission:
raise Forbidden() parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "credential_id", type=str, required=True, nullable=False, location="json"
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
@ -203,17 +200,19 @@ class DatasourceAuthUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if not current_user.has_edit_permission:
raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
@ -257,14 +256,15 @@ class DatasourceAuthOauthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -296,13 +296,11 @@ class DatasourceAuthDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission: parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -319,14 +317,15 @@ class DatasourceUpdateProviderNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()

View File

@ -26,10 +26,12 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")

View File

@ -66,26 +66,28 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=_validate_description_length, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=_validate_description_length, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()

View File

@ -26,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"yaml_content", "yaml_content",
type=str, type=str,
nullable=False, nullable=False,

View File

@ -23,7 +23,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -28,26 +24,29 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("mode", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("yaml_content", type=str, location="json") .add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_url", type=str, location="json") .add_argument("yaml_content", type=str, location="json")
parser.add_argument("name", type=str, location="json") .add_argument("yaml_url", type=str, location="json")
parser.add_argument("description", type=str, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json") .add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_rag_pipeline( result = import_service.import_rag_pipeline(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -74,15 +73,16 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self, import_id): def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first # Check user role first
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
@ -100,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields) @marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -117,12 +118,12 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
parser.add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
edit_permission_required,
setup_required, setup_required,
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
) )
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.model import EndUser from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get draft rag pipeline's workflow Get draft rag pipeline's workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json") .add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -154,16 +154,15 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -194,11 +193,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -229,14 +228,17 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -264,17 +266,20 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow Run published workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) .add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming") .add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -303,15 +308,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -344,15 +350,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -385,13 +392,16 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -428,13 +438,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -472,11 +485,13 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node Run draft workflow node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") "inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -505,7 +520,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task Stop workflow task
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -525,7 +541,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline Get published pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
if not pipeline.is_published: if not pipeline.is_published:
return None return None
@ -545,7 +562,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow Publish workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -580,7 +598,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Get default block configs # Get default block configs
@ -599,11 +618,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -631,14 +650,17 @@ class PublishedAllRagPipelineApi(Resource):
""" """
Get published workflows Get published workflows
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("user_id", type=str, required=False, location="args") .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") .add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = int(args.get("page", 1)) page = int(args.get("page", 1))
limit = int(args.get("limit", 10)) limit = int(args.get("limit", 10))
@ -681,12 +703,15 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes Update workflow attributes
""" """
# Check permission # Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("marked_name", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("marked_comment", type=str, required=False, location="json") .add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -733,15 +758,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -759,15 +781,12 @@ class PublishedRagPipelineFirstStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -785,15 +804,12 @@ class DraftRagPipelineFirstStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -811,15 +827,12 @@ class DraftRagPipelineSecondStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -843,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -880,7 +895,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields) @marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id): def get(self, pipeline: Pipeline, run_id: str):
""" """
Get workflow run node execution list Get workflow run node execution list
""" """
@ -903,14 +918,8 @@ class DatasourceListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, current_tenant_id = current_account_with_tenant()
if not isinstance(user, Account): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@ -940,9 +949,8 @@ class RagPipelineTransformApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id: str):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
if not (current_user.has_edit_permission or current_user.is_dataset_operator): if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
@ -959,19 +967,20 @@ class RagPipelineDatasourceVariableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Set datasource variables Set datasource variables
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("datasource_info", type=dict, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json") .add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("start_node_title", type=str, required=True, location="json")
parser.add_argument("start_node_title", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()

View File

@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"provider", .add_argument(
type=str, "provider",
choices=["firecrawl", "watercrawl", "jinareader"], type=str,
required=True, choices=["firecrawl", "watercrawl", "jinareader"],
nullable=True, required=True,
location="json", nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
) )
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"): if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters") raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user is not an account")
pipeline_id = kwargs.get("pipeline_id") pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id) pipeline_id = str(pipeline_id)
@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = ( pipeline = (
db.session.query(Pipeline) db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first() .first()
) )

View File

@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("voice", type=str, location="json") .add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("text", type=str, location="json") .add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)

View File

@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
args["auto_generate_name"] = False args["auto_generate_name"] = False

View File

@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args() args = parser.parse_args()
pinned = None pinned = None
@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") .add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -15,7 +15,6 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -67,31 +66,26 @@ class InstalledAppsListApi(Resource):
# Pre-filter out apps without setting or with sso_verified # Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = [] filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list: for installed_app in installed_app_list:
app_id = installed_app["app"].id app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id) webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified": if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue continue
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app) filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check # Batch permission check
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id, user_id=user_id,
app_codes=app_codes, app_ids=app_ids,
) )
# Keep only allowed apps # Keep only allowed apps
res = [] res = []
for installed_app in filtered_installed_apps: for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id] if permissions.get(app_id):
if permissions.get(app_code):
res.append(installed_app) res.append(installed_app)
installed_app_list = res installed_app_list = res
@ -111,8 +105,7 @@ class InstalledAppsListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
@ -170,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204 return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app): def patch(self, installed_app):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
parser.add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args() args = parser.parse_args()
commit_args = False commit_args = False

View File

@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -48,21 +47,22 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("first_id", type=uuid_value, location="args") .add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
) )
@ -78,18 +78,19 @@ class MessageListApi(InstalledAppResource):
) )
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") reqparse.RequestParser()
parser.add_argument("content", type=str, location="json") .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
@ -109,14 +110,14 @@ class MessageFeedbackApi(InstalledAppResource):
) )
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
) )
args = parser.parse_args() args = parser.parse_args()
@ -124,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
@ -159,6 +158,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
) )
class MessageSuggestedQuestionApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -167,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
) )

View File

@ -42,8 +42,7 @@ class RecommendedAppListApi(Resource):
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
# language args # language args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
parser.add_argument("language", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
language = args.get("language") language = args.get("language")

View File

@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
@ -35,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app): def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -72,6 +70,7 @@ class SavedMessageListApi(InstalledAppResource):
) )
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id): def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
@ -79,8 +78,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper from libs import helper
from libs.login import current_user from libs.login import current_user as current_user_
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -31,6 +31,8 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -45,9 +47,11 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("files", type=list, required=False, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
assert current_user is not None assert current_user is not None
try: try:

View File

@ -8,10 +8,8 @@ from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import InstalledApp from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -24,13 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first() .first()
) )
@ -56,14 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id), user_id=str(current_user.id),
app_code=app_code, app_id=app_id,
) )
if not res: if not res:
raise AppAccessDeniedError() raise AppAccessDeniedError()

View File

@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
@ -30,8 +29,7 @@ class CodeBasedExtensionAPI(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
parser.add_argument("module", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
@ -68,12 +66,12 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def post(self): def post(self):
assert isinstance(current_user, Account) parser = (
assert current_user.current_tenant_id is not None reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("name", type=str, required=True, location="json")
parser.add_argument("name", type=str, required=True, location="json") .add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json") .add_argument("api_key", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -98,8 +96,6 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def get(self, id): def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
@ -124,17 +120,17 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def post(self, id): def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("api_endpoint", type=str, required=True, location="json") .add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json") .add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
extension_data_from_db.name = args["name"] extension_data_from_db.name = args["name"]
@ -153,8 +149,6 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, id): def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()

View File

@ -1,7 +1,6 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -22,8 +21,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns from . import console_ns
@ -53,6 +51,7 @@ class FileApi(Resource):
@marshal_with(file_fields) @marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents") @cloud_edition_billing_resource_check("documents")
def post(self): def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source") source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -65,16 +64,12 @@ class FileApi(Resource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor: if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
if source not in ("datasets", None): if source not in ("datasets", None):
source = None source = None
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
try: try:
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file.filename, filename=file.filename,

View File

@ -57,8 +57,7 @@ class InitValidateAPI(Resource):
if tenant_count > 0: if tenant_count > 0:
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"] input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):

View File

@ -40,8 +40,7 @@ class RemoteFileInfoApi(Resource):
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args() args = parser.parse_args()
url = args["url"] url = args["url"]

View File

@ -69,10 +69,12 @@ class SetupApi(Resource):
if not get_init_validate_status(): if not get_init_validate_status():
raise NotInitValidateError() raise NotInitValidateError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") .add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# setup # setup

View File

@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import Tag from models.model import Tag
from services.tag_service import TagService from services.tag_service import TagService
@ -24,11 +23,10 @@ class TagListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(dataset_tag_fields) @marshal_with(dataset_tag_fields)
def get(self): def get(self):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_type = request.args.get("type", type=str, default="") tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str) keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
return tags, 200 return tags, 200
@ -36,18 +34,23 @@ class TagListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name .add_argument(
) "name",
parser.add_argument( nullable=False,
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
tag = TagService.save_tags(args) tag = TagService.save_tags(args)
@ -63,15 +66,13 @@ class TagUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, tag_id): def patch(self, tag_id):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
) )
args = parser.parse_args() args = parser.parse_args()
@ -87,8 +88,7 @@ class TagUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, tag_id): def delete(self, tag_id):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
@ -105,21 +105,22 @@ class TagBindingCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." .add_argument(
) "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
parser.add_argument( )
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." .add_argument(
) "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
parser.add_argument( )
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." .add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
TagService.save_tag_binding(args) TagService.save_tag_binding(args)
@ -133,17 +134,18 @@ class TagBindingDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") reqparse.RequestParser()
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument( .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." .add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
TagService.delete_tag_binding(args) TagService.delete_tag_binding(args)

View File

@ -37,8 +37,7 @@ class VersionApi(Resource):
) )
def get(self): def get(self):
"""Check for application version updates""" """Check for application version updates"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
parser.add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL check_update_url = dify_config.CHECK_UPDATE_URL

View File

@ -2,11 +2,11 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P") P = ParamSpec("P")
@ -20,8 +20,9 @@ def plugin_permission_required(
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
with Session(db.engine) as session: with Session(db.engine) as session:
permission = ( permission = (

View File

@ -2,7 +2,6 @@ from datetime import datetime
import pytz import pytz
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -37,9 +36,8 @@ from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode from models import Account, AccountIntegrate, InvitationCode
from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -50,9 +48,7 @@ class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
if account.status == "active": if account.status == "active":
raise AccountAlreadyInitedError() raise AccountAlreadyInitedError()
@ -61,9 +57,9 @@ class AccountInitApi(Resource):
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json") parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
parser.add_argument("interface_language", type=supported_language, required=True, location="json") "timezone", type=timezone, required=True, location="json"
parser.add_argument("timezone", type=timezone, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields) @marshal_with(account_fields)
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
return current_user return current_user
@ -118,10 +113,8 @@ class AccountNameApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate account name length # Validate account name length
@ -140,10 +133,8 @@ class AccountAvatarApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
@ -158,10 +149,10 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "interface_language", type=supported_language, required=True, location="json"
parser.add_argument("interface_language", type=supported_language, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
@ -176,10 +167,10 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
@ -194,10 +185,8 @@ class AccountTimezoneApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai # Validate timezone string, e.g. America/New_York, Asia/Shanghai
@ -216,12 +205,13 @@ class AccountPasswordApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json") .add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json") .add_argument("new_password", type=str, required=True, location="json")
parser.add_argument("repeat_new_password", type=str, required=True, location="json") .add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["new_password"] != args["repeat_new_password"]: if args["new_password"] != args["repeat_new_password"]:
@ -253,9 +243,7 @@ class AccountIntegrateApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
account_integrates = db.session.scalars( account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
@ -298,9 +286,7 @@ class AccountDeleteVerifyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
token, code = AccountService.generate_account_deletion_verification_code(account) token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code) AccountService.send_account_deletion_verification_email(account, code)
@ -314,13 +300,13 @@ class AccountDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]): if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
@ -335,9 +321,11 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("feedback", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@ -358,9 +346,7 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(verify_fields) @marshal_with(verify_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@ -380,14 +366,14 @@ class EducationApi(Resource):
@only_edition_cloud @only_edition_cloud
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("institution", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json") .add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"]) return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@ -399,9 +385,7 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(status_fields) @marshal_with(status_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
res = BillingService.EducationIdentity.status(account.id) res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat # convert expire_at to UTC timestamp from isoformat
@ -425,10 +409,12 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(data_fields) @marshal_with(data_fields)
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("keywords", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=0) .add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("limit", type=int, required=False, location="args", default=20) .add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args() args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@ -441,11 +427,14 @@ class ChangeEmailSendEmailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("email", type=email, required=True, location="json") parser = (
parser.add_argument("language", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("phase", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("token", type=str, required=False, location="json") .add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -467,8 +456,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email: if user_email != current_user.email:
raise InvalidEmailError() raise InvalidEmailError()
else: else:
@ -490,10 +477,12 @@ class ChangeEmailCheckApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -533,9 +522,11 @@ class ChangeEmailResetApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("new_email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]): if AccountService.is_account_in_freeze(args["new_email"]):
@ -551,8 +542,7 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()
@ -569,8 +559,7 @@ class ChangeEmailResetApi(Resource):
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
parser.add_argument("email", type=email, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["email"]): if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError() raise AccountInFreezeError()

View File

@ -3,8 +3,7 @@ from flask_restx import Resource, fields
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.agent_service import AgentService from services.agent_service import AgentService
@ -21,12 +20,11 @@ class AgentProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
assert user.current_tenant_id is not None
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@ -45,9 +43,5 @@ class AgentProviderApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_name: str): def get(self, provider_name: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
user = current_user return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))

View File

@ -37,10 +37,12 @@ class EndpointCreateApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("plugin_unique_identifier", type=str, required=True) reqparse.RequestParser()
parser.add_argument("settings", type=dict, required=True) .add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("name", type=str, required=True) .add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args() args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"] plugin_unique_identifier = args["plugin_unique_identifier"]
@ -81,9 +83,11 @@ class EndpointListApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = args["page"] page = args["page"]
@ -124,10 +128,12 @@ class EndpointListForSinglePluginApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args") .add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = args["page"] page = args["page"]
@ -166,8 +172,7 @@ class EndpointDeleteApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
@ -206,10 +211,12 @@ class EndpointUpdateApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("endpoint_id", type=str, required=True) reqparse.RequestParser()
parser.add_argument("settings", type=dict, required=True) .add_argument("endpoint_id", type=str, required=True)
parser.add_argument("name", type=str, required=True) .add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
@ -249,8 +256,7 @@ class EndpointEnableApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
@ -282,8 +288,7 @@ class EndpointDisableApi(Resource):
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]

View File

@ -5,8 +5,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
@ -18,24 +18,25 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# validate model load balancing credentials # validate model load balancing credentials
@ -72,24 +73,25 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str, config_id: str): def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# validate model load balancing config credentials # validate model load balancing config credentials

View File

@ -57,10 +57,12 @@ class MemberInviteEmailApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("members") @cloud_edition_billing_resource_check("members")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("emails", type=list, required=True, location="json") reqparse.RequestParser()
parser.add_argument("role", type=str, required=True, default="admin", location="json") .add_argument("emails", type=list, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json") .add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
invitee_emails = args["emails"] invitee_emails = args["emails"]
@ -149,8 +151,7 @@ class MemberUpdateRoleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, member_id): def put(self, member_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
new_role = args["role"] new_role = args["role"]
@ -199,8 +200,7 @@ class SendOwnerTransferEmailApi(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
@ -236,9 +236,11 @@ class OwnerTransferCheckApi(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("code", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
@ -281,8 +283,9 @@ class OwnerTransfer(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self, member_id): def post(self, member_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("token", type=str, required=True, nullable=False, location="json") "token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace

View File

@ -24,8 +24,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id tenant_id = current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"model_type", "model_type",
type=str, type=str,
required=False, required=False,
@ -50,8 +49,9 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") "credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -69,9 +69,11 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -96,10 +98,12 @@ class ModelProviderCredentialApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -124,8 +128,9 @@ class ModelProviderCredentialApi(Resource):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") "credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -145,8 +150,9 @@ class ModelProviderCredentialSwitchApi(Resource):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") "credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
service = ModelProviderService() service = ModelProviderService()
@ -165,8 +171,9 @@ class ModelProviderValidateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") "credentials", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_tenant_id tenant_id = current_tenant_id
@ -223,8 +230,7 @@ class PreferredProviderTypeUpdateApi(Resource):
tenant_id = current_tenant_id tenant_id = current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"preferred_provider_type", "preferred_provider_type",
type=str, type=str,
required=True, required=True,

View File

@ -24,8 +24,7 @@ class DefaultModelApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"model_type", "model_type",
type=str, type=str,
required=True, required=True,
@ -51,8 +50,9 @@ class DefaultModelApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") "model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_settings = args["model_settings"] model_settings = args["model_settings"]
@ -107,19 +107,21 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
) )
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if args.get("config_from", "") == "custom-model": if args.get("config_from", "") == "custom-model":
@ -167,15 +169,17 @@ class ModelProviderModelApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -195,18 +199,20 @@ class ModelProviderModelCredentialApi(Resource):
def get(self, provider: str): def get(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="args") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="args")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="args", choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
) )
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -260,18 +266,20 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -305,19 +313,21 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -345,17 +355,19 @@ class ModelProviderModelCredentialApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -380,17 +392,19 @@ class ModelProviderModelCredentialSwitchApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
) )
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
service = ModelProviderService() service = ModelProviderService()
@ -414,15 +428,17 @@ class ModelProviderModelEnableApi(Resource):
def patch(self, provider: str): def patch(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -444,15 +460,17 @@ class ModelProviderModelDisableApi(Resource):
def patch(self, provider: str): def patch(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -472,17 +490,19 @@ class ModelProviderModelValidateApi(Resource):
def post(self, provider: str): def post(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -516,8 +536,9 @@ class ModelProviderModelParameterRuleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("model", type=str, required=True, nullable=False, location="args") "model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args() args = parser.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()

View File

@ -44,9 +44,11 @@ class PluginListApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=False, location="args", default=1) reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=False, location="args", default=256) .add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args() args = parser.parse_args()
try: try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"]) plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
@ -62,8 +64,7 @@ class PluginListLatestVersionsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
req.add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args() args = req.parse_args()
try: try:
@ -82,8 +83,7 @@ class PluginListInstallationsFromIdsApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -98,9 +98,11 @@ class PluginListInstallationsFromIdsApi(Resource):
class PluginIconApi(Resource): class PluginIconApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
req = reqparse.RequestParser() req = (
req.add_argument("tenant_id", type=str, required=True, location="args") reqparse.RequestParser()
req.add_argument("filename", type=str, required=True, location="args") .add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args() args = req.parse_args()
try: try:
@ -145,10 +147,12 @@ class PluginUploadFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("repo", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -192,8 +196,9 @@ class PluginInstallFromPkgApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") "plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string # check if all plugin_unique_identifiers are valid string
@ -218,11 +223,13 @@ class PluginInstallFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("repo", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json") .add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -248,8 +255,9 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") "plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string # check if all plugin_unique_identifiers are valid string
@ -274,8 +282,9 @@ class PluginFetchMarketplacePkgApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") "plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -300,8 +309,9 @@ class PluginFetchManifestApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") "plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -325,9 +335,11 @@ class PluginFetchInstallTasksApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -407,9 +419,11 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -431,12 +445,14 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json") .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -461,8 +477,7 @@ class PluginUninstallApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args() args = req.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
@ -484,9 +499,11 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
req = reqparse.RequestParser() req = (
req.add_argument("install_permission", type=str, required=True, location="json") reqparse.RequestParser()
req.add_argument("debug_permission", type=str, required=True, location="json") .add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args() args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
@ -535,12 +552,14 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
user_id = current_user.id user_id = current_user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("plugin_id", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, location="args") .add_argument("plugin_id", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args") .add_argument("provider", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args") .add_argument("action", type=str, required=True, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args") .add_argument("parameter", type=str, required=True, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -569,9 +588,11 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
req = reqparse.RequestParser() req = (
req.add_argument("permission", type=dict, required=True, location="json") reqparse.RequestParser()
req.add_argument("auto_upgrade", type=dict, required=True, location="json") .add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args() args = req.parse_args()
permission = args["permission"] permission = args["permission"]
@ -661,8 +682,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin # exclude one single plugin
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
req.add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args() args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})

View File

@ -60,8 +60,7 @@ class ToolProviderListApi(Resource):
user_id = user.id user_id = user.id
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument(
req.add_argument(
"type", "type",
type=str, type=str,
choices=["builtin", "model", "api", "workflow", "mcp"], choices=["builtin", "model", "api", "workflow", "mcp"],
@ -111,8 +110,9 @@ class ToolBuiltinProviderDeleteApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument(
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") "credential_id", type=str, required=True, nullable=False, location="json"
)
args = req.parse_args() args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider( return BuiltinToolManageService.delete_builtin_tool_provider(
@ -132,10 +132,12 @@ class ToolBuiltinProviderAddApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json") .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["type"] not in CredentialType.values(): if args["type"] not in CredentialType.values():
@ -164,10 +166,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -220,15 +224,17 @@ class ToolApiProviderAddApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -256,9 +262,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -279,9 +283,9 @@ class ToolApiProviderListToolsApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") )
args = parser.parse_args() args = parser.parse_args()
@ -307,16 +311,18 @@ class ToolApiProviderUpdateApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("original_provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") .add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -348,9 +354,9 @@ class ToolApiProviderDeleteApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
@ -371,9 +377,9 @@ class ToolApiProviderGetApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") )
args = parser.parse_args() args = parser.parse_args()
@ -405,9 +411,9 @@ class ToolApiProviderSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
@ -422,14 +428,15 @@ class ToolApiProviderPreviousTestApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") .add_argument("tool_name", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") .add_argument("provider_name", type=str, required=False, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") .add_argument("parameters", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -457,15 +464,17 @@ class ToolWorkflowProviderCreateApi(Resource):
user_id = user.id user_id = user.id
reqparser = reqparse.RequestParser() reqparser = (
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") .add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -496,15 +505,17 @@ class ToolWorkflowProviderUpdateApi(Resource):
user_id = user.id user_id = user.id
reqparser = reqparse.RequestParser() reqparser = (
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") .add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -538,8 +549,9 @@ class ToolWorkflowProviderDeleteApi(Resource):
user_id = user.id user_id = user.id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser().add_argument(
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -560,9 +572,11 @@ class ToolWorkflowProviderGetApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") reqparse.RequestParser()
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -594,8 +608,9 @@ class ToolWorkflowProviderListToolApi(Resource):
user_id = user.id user_id = user.id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args() args = parser.parse_args()
@ -780,8 +795,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
return BuiltinToolManageService.set_default_provider( return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
@ -794,9 +808,11 @@ class ToolOAuthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
parser = reqparse.RequestParser() parser = (
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -866,16 +882,18 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") .add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args() args = parser.parse_args()
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
@ -909,18 +927,19 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self): def put(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") .add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("configuration", type=dict, required=False, nullable=True, location="json") .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json") .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
)
args = parser.parse_args() args = parser.parse_args()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]: if "[__HIDDEN__]" in args["server_url"]:
@ -952,8 +971,9 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self): def delete(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") "provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -969,9 +989,11 @@ class ToolMCPAuthApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
provider_id = args["provider_id"] provider_id = args["provider_id"]
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
@ -1066,9 +1088,11 @@ class ToolMCPUpdateApi(Resource):
@console_ns.route("/mcp/oauth/callback") @console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("code", type=str, required=True, nullable=False, location="args") reqparse.RequestParser()
parser.add_argument("state", type=str, required=True, nullable=False, location="args") .add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
state_key = args["state"] state_key = args["state"]
authorization_code = args["code"] authorization_code = args["code"]

View File

@ -23,8 +23,8 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, Tenant, TenantStatus from models.account import Tenant, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.file_service import FileService from services.file_service import FileService
@ -70,8 +70,7 @@ class TenantListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = [] tenant_dicts = []
@ -85,7 +84,7 @@ class TenantListApi(Resource):
"status": tenant.status, "status": tenant.status,
"created_at": tenant.created_at, "created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, "current": tenant.id == current_tenant_id if current_tenant_id else False,
} }
tenant_dicts.append(tenant_dict) tenant_dicts.append(tenant_dict)
@ -98,9 +97,11 @@ class WorkspaceListApi(Resource):
@setup_required @setup_required
@admin_required @admin_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
stmt = select(Tenant).order_by(Tenant.created_at.desc()) stmt = select(Tenant).order_by(Tenant.created_at.desc())
@ -130,8 +131,7 @@ class TenantApi(Resource):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
tenant = current_user.current_tenant tenant = current_user.current_tenant
if not tenant: if not tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
@ -155,10 +155,8 @@ class SwitchWorkspaceApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# check if tenant_id is valid, 403 if not # check if tenant_id is valid, 403 if not
@ -181,16 +179,14 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json") .add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json") .add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
tenant = db.get_or_404(Tenant, current_tenant_id)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"], "remove_webapp_brand": args["remove_webapp_brand"],
@ -212,8 +208,7 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -253,15 +248,13 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required @account_initialization_required
# Change workspace name # Change workspace name
def post(self): def post(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id: if not current_tenant_id:
raise ValueError("No current tenant") raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

View File

@ -30,10 +30,7 @@ def account_initialization_required(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
account = current_user
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError() raise AccountNotInitializedError()
return view(*args, **kwargs) return view(*args, **kwargs)
@ -249,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated return decorated
def email_register_enabled(view): def email_register_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.is_allow_register: if features.is_allow_register:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -299,3 +296,21 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
abort(403) abort(403)
return decorated return decorated
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -46,11 +46,13 @@ class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("timestamp", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("nonce", type=str, required=True, location="args") .add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args") .add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") .add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -16,12 +16,13 @@ class ToolFileApi(Resource):
def get(self, file_id, extension): def get(self, file_id, extension):
file_id = str(file_id) file_id = str(file_id)
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args") .add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args") .add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args") .add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") .add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
if not verify_tool_file_signature( if not verify_tool_file_signature(

View File

@ -18,19 +18,17 @@ from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model from fields.file_fields import build_file_model
# Define parser for both documentation and validation # Define parser for both documentation and validation
upload_parser = reqparse.RequestParser() upload_parser = (
upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") reqparse.RequestParser()
upload_parser.add_argument( .add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" .add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
)
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
) )
upload_parser.add_argument(
"nonce", type=str, required=True, location="args", help="Random string for signature verification"
)
upload_parser.add_argument(
"sign", type=str, required=True, location="args", help="HMAC signature for request validation"
)
upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
@files_ns.route("/upload/for-plugin") @files_ns.route("/upload/for-plugin")

View File

@ -5,11 +5,13 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task from tasks.mail_inner_task import send_inner_email_task
_mail_parser = reqparse.RequestParser() _mail_parser = (
_mail_parser.add_argument("to", type=str, action="append", required=True) reqparse.RequestParser()
_mail_parser.add_argument("subject", type=str, required=True) .add_argument("to", type=str, action="append", required=True)
_mail_parser.add_argument("body", type=str, required=True) .add_argument("subject", type=str, required=True)
_mail_parser.add_argument("substitutions", type=dict, required=False) .add_argument("body", type=str, required=True)
.add_argument("substitutions", type=dict, required=False)
)
class BaseMail(Resource): class BaseMail(Resource):
@ -17,7 +19,7 @@ class BaseMail(Resource):
def post(self): def post(self):
args = _mail_parser.parse_args() args = _mail_parser.parse_args()
send_inner_email_task.delay( send_inner_email_task.delay( # type: ignore
to=args["to"], to=args["to"],
subject=args["subject"], subject=args["subject"],
body=args["body"], body=args["body"],

View File

@ -31,7 +31,7 @@ from core.plugin.entities.request import (
) )
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import length_prefixed_response from libs.helper import length_prefixed_response
from models.account import Account, Tenant from models import Account, Tenant
from models.model import EndUser from models.model import EndUser

View File

@ -72,9 +72,11 @@ def get_user_tenant(view: Callable[P, R] | None = None):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body # fetch json body
parser = reqparse.RequestParser() parser = (
parser.add_argument("tenant_id", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("user_id", type=str, required=True, location="json") .add_argument("tenant_id", type=str, required=True, location="json")
.add_argument("user_id", type=str, required=True, location="json")
)
p = parser.parse_args() p = parser.parse_args()

View File

@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models import Account
from services.account_service import TenantService from services.account_service import TenantService
@ -25,9 +25,11 @@ class EnterpriseWorkspace(Resource):
} }
) )
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("owner_email", type=str, required=True, location="json") .add_argument("name", type=str, required=True, location="json")
.add_argument("owner_email", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
account = db.session.query(Account).filter_by(email=args["owner_email"]).first() account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
@ -68,8 +70,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
} }
) )
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)

View File

@ -33,14 +33,12 @@ def int_or_str(value):
# Define parser for both documentation and validation # Define parser for both documentation and validation
mcp_request_parser = reqparse.RequestParser() mcp_request_parser = (
mcp_request_parser.add_argument( reqparse.RequestParser()
"jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')" .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
) .add_argument("method", type=str, required=True, location="json", help="The method to invoke")
mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke") .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
mcp_request_parser.add_argument(
"id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses"
) )

View File

@ -10,24 +10,24 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user from libs.login import current_user
from models.account import Account from models import Account
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
# Define parsers for annotation API # Define parsers for annotation API
annotation_create_parser = reqparse.RequestParser() annotation_create_parser = (
annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") reqparse.RequestParser()
annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") .add_argument("question", required=True, type=str, location="json", help="Annotation question")
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
)
annotation_reply_action_parser = reqparse.RequestParser() annotation_reply_action_parser = (
annotation_reply_action_parser.add_argument( reqparse.RequestParser()
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" .add_argument(
) "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
annotation_reply_action_parser.add_argument( )
"embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
) .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
annotation_reply_action_parser.add_argument(
"embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
) )

View File

@ -85,11 +85,13 @@ class AudioApi(Resource):
# Define parser for text-to-audio API # Define parser for text-to-audio API
text_to_audio_parser = reqparse.RequestParser() text_to_audio_parser = (
text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") reqparse.RequestParser()
text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") .add_argument("message_id", type=str, required=False, location="json", help="Message ID")
text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") .add_argument("voice", type=str, location="json", help="Voice to use for TTS")
text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") .add_argument("text", type=str, location="json", help="Text to convert to audio")
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
)
@service_api_ns.route("/text-to-audio") @service_api_ns.route("/text-to-audio")

View File

@ -37,40 +37,34 @@ logger = logging.getLogger(__name__)
# Define parser for completion API # Define parser for completion API
completion_parser = reqparse.RequestParser() completion_parser = (
completion_parser.add_argument( reqparse.RequestParser()
"inputs", type=dict, required=True, location="json", help="Input parameters for completion" .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
) .add_argument("query", type=str, location="json", default="", help="The query string")
completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
completion_parser.add_argument( .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
)
completion_parser.add_argument(
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
) )
# Define parser for chat API # Define parser for chat API
chat_parser = reqparse.RequestParser() chat_parser = (
chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") reqparse.RequestParser()
chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") .add_argument("query", type=str, required=True, location="json", help="The chat query")
chat_parser.add_argument( .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
) )
chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
chat_parser.add_argument(
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
)
chat_parser.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
@service_api_ns.route("/completion-messages") @service_api_ns.route("/completion-messages")

View File

@ -24,48 +24,63 @@ from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
# Define parsers for conversation APIs # Define parsers for conversation APIs
conversation_list_parser = reqparse.RequestParser() conversation_list_parser = (
conversation_list_parser.add_argument( reqparse.RequestParser()
"last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
) .add_argument(
conversation_list_parser.add_argument( "limit",
"limit", type=int_range(1, 100),
type=int_range(1, 100), required=False,
required=False, default=20,
default=20, location="args",
location="args", help="Number of conversations to return",
help="Number of conversations to return", )
) .add_argument(
conversation_list_parser.add_argument( "sort_by",
"sort_by", type=str,
type=str, choices=["created_at", "-created_at", "updated_at", "-updated_at"],
choices=["created_at", "-created_at", "updated_at", "-updated_at"], required=False,
required=False, default="-updated_at",
default="-updated_at", location="args",
location="args", help="Sort order for conversations",
help="Sort order for conversations", )
) )
conversation_rename_parser = reqparse.RequestParser() conversation_rename_parser = (
conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") reqparse.RequestParser()
conversation_rename_parser.add_argument( .add_argument("name", type=str, required=False, location="json", help="New conversation name")
"auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" .add_argument(
"auto_generate",
type=bool,
required=False,
default=False,
location="json",
help="Auto-generate conversation name",
)
) )
conversation_variables_parser = reqparse.RequestParser() conversation_variables_parser = (
conversation_variables_parser.add_argument( reqparse.RequestParser()
"last_id", type=uuid_value, location="args", help="Last variable ID for pagination" .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
) .add_argument(
conversation_variables_parser.add_argument( "limit",
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of variables to return",
)
) )
conversation_variable_update_parser = reqparse.RequestParser() conversation_variable_update_parser = reqparse.RequestParser().add_argument(
# using lambda is for passing the already-typed value without modification # using lambda is for passing the already-typed value without modification
# if no lambda, it will be converted to string # if no lambda, it will be converted to string
# the string cannot be converted using json.loads # the string cannot be converted using json.loads
conversation_variable_update_parser.add_argument( "value",
"value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" required=True,
location="json",
type=lambda x: x,
help="New value for the conversation variable",
) )

View File

@ -18,8 +18,7 @@ logger = logging.getLogger(__name__)
# Define parser for file preview API # Define parser for file preview API
file_preview_parser = reqparse.RequestParser() file_preview_parser = reqparse.RequestParser().add_argument(
file_preview_parser.add_argument(
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
) )

View File

@ -26,25 +26,37 @@ logger = logging.getLogger(__name__)
# Define parsers for message APIs # Define parsers for message APIs
message_list_parser = reqparse.RequestParser() message_list_parser = (
message_list_parser.add_argument( reqparse.RequestParser()
"conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
) .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") .add_argument(
message_list_parser.add_argument( "limit",
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of messages to return",
)
) )
message_feedback_parser = reqparse.RequestParser() message_feedback_parser = (
message_feedback_parser.add_argument( reqparse.RequestParser()
"rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
.add_argument("content", type=str, location="json", help="Feedback content")
) )
message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content")
feedback_list_parser = reqparse.RequestParser() feedback_list_parser = (
feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") reqparse.RequestParser()
feedback_list_parser.add_argument( .add_argument("page", type=int, default=1, location="args", help="Page number")
"limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" .add_argument(
"limit",
type=int_range(1, 101),
required=False,
default=20,
location="args",
help="Number of feedbacks per page",
)
) )

View File

@ -42,32 +42,36 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parsers for workflow APIs # Define parsers for workflow APIs
workflow_run_parser = reqparse.RequestParser() workflow_run_parser = (
workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
workflow_run_parser.add_argument("files", type=list, required=False, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
)
workflow_log_parser = reqparse.RequestParser() workflow_log_parser = (
workflow_log_parser.add_argument("keyword", type=str, location="args") reqparse.RequestParser()
workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") .add_argument("keyword", type=str, location="args")
workflow_log_parser.add_argument("created_at__before", type=str, location="args") .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
workflow_log_parser.add_argument("created_at__after", type=str, location="args") .add_argument("created_at__before", type=str, location="args")
workflow_log_parser.add_argument( .add_argument("created_at__after", type=str, location="args")
"created_by_end_user_session_id", .add_argument(
type=str, "created_by_end_user_session_id",
location="args", type=str,
required=False, location="args",
default=None, required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
workflow_log_parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
workflow_run_fields = { workflow_run_fields = {
"id": fields.String, "id": fields.String,

View File

@ -33,119 +33,118 @@ def _validate_name(name):
# Define parsers for dataset operations # Define parsers for dataset operations
dataset_create_parser = reqparse.RequestParser() dataset_create_parser = (
dataset_create_parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
dataset_create_parser.add_argument( )
"description", .add_argument(
type=validate_description_length, "description",
nullable=True, type=validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
dataset_create_parser.add_argument( )
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
help="Invalid indexing technique.", choices=Dataset.INDEXING_TECHNIQUE_LIST,
) help="Invalid indexing technique.",
dataset_create_parser.add_argument( )
"permission", .add_argument(
type=str, "permission",
location="json", type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), location="json",
help="Invalid permission.", choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
required=False, help="Invalid permission.",
nullable=False, required=False,
) nullable=False,
dataset_create_parser.add_argument( )
"external_knowledge_api_id", .add_argument(
type=str, "external_knowledge_api_id",
nullable=True, type=str,
required=False, nullable=True,
default="_validate_name", required=False,
) default="_validate_name",
dataset_create_parser.add_argument( )
"provider", .add_argument(
type=str, "provider",
nullable=True, type=str,
required=False, nullable=True,
default="vendor", required=False,
) default="vendor",
dataset_create_parser.add_argument( )
"external_knowledge_id", .add_argument(
type=str, "external_knowledge_id",
nullable=True, type=str,
required=False, nullable=True,
) required=False,
dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") )
dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
dataset_update_parser = reqparse.RequestParser()
dataset_update_parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
dataset_update_parser.add_argument(
"description", location="json", store_missing=False, type=validate_description_length
)
dataset_update_parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
dataset_update_parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
dataset_update_parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
dataset_update_parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
dataset_update_parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
dataset_update_parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
) )
tag_create_parser = reqparse.RequestParser() dataset_update_parser = (
tag_create_parser.add_argument( reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument(
"name", "name",
nullable=False, nullable=False,
required=True, required=True,
@ -155,32 +154,37 @@ tag_create_parser.add_argument(
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
) )
tag_update_parser = reqparse.RequestParser() tag_update_parser = (
tag_update_parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 50 characters.", required=True,
type=lambda x: x help="Name must be between 1 to 50 characters.",
if x and 1 <= len(x) <= 50 type=lambda x: x
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), if x and 1 <= len(x) <= 50
) else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) )
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
tag_delete_parser = reqparse.RequestParser()
tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
tag_binding_parser = reqparse.RequestParser()
tag_binding_parser.add_argument(
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
)
tag_binding_parser.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
) )
tag_unbinding_parser = reqparse.RequestParser() tag_delete_parser = reqparse.RequestParser().add_argument(
tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") "tag_id", nullable=False, required=True, help="Id of a tag.", type=str
tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") )
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
)
@service_api_ns.route("/datasets") @service_api_ns.route("/datasets")

View File

@ -35,37 +35,31 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon
from services.file_service import FileService from services.file_service import FileService
# Define parsers for document operations # Define parsers for document operations
document_text_create_parser = reqparse.RequestParser() document_text_create_parser = (
document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") .add_argument("name", type=str, required=True, nullable=False, location="json")
document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") .add_argument("text", type=str, required=True, nullable=False, location="json")
document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
document_text_create_parser.add_argument( .add_argument("original_document_id", type=str, required=False, location="json")
"doc_form", type=str, default="text_model", required=False, nullable=False, location="json" .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
) .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
document_text_create_parser.add_argument( .add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
) )
document_text_create_parser.add_argument( .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
) .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
document_text_create_parser.add_argument(
"embedding_model_provider", type=str, required=False, nullable=True, location="json"
) )
document_text_update_parser = reqparse.RequestParser() document_text_update_parser = (
document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") .add_argument("name", type=str, required=False, nullable=True, location="json")
document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") .add_argument("text", type=str, required=False, nullable=True, location="json")
document_text_update_parser.add_argument( .add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
"doc_form", type=str, default="text_model", required=False, nullable=False, location="json" .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
) )
document_text_update_parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
@service_api_ns.route( @service_api_ns.route(

View File

@ -15,21 +15,17 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
# Define parsers for metadata APIs # Define parsers for metadata APIs
metadata_create_parser = reqparse.RequestParser() metadata_create_parser = (
metadata_create_parser.add_argument( reqparse.RequestParser()
"type", type=str, required=True, nullable=False, location="json", help="Metadata type" .add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
) .add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
metadata_create_parser.add_argument(
"name", type=str, required=True, nullable=False, location="json", help="Metadata name"
) )
metadata_update_parser = reqparse.RequestParser() metadata_update_parser = reqparse.RequestParser().add_argument(
metadata_update_parser.add_argument(
"name", type=str, required=True, nullable=False, location="json", help="New metadata name" "name", type=str, required=True, nullable=False, location="json", help="New metadata name"
) )
document_metadata_parser = reqparse.RequestParser() document_metadata_parser = reqparse.RequestParser().add_argument(
document_metadata_parser.add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
) )

Some files were not shown because too many files have changed in this diff Show More