merge main

This commit is contained in:
JzoNg 2025-10-20 14:21:09 +08:00
commit a4e2ef6b0c
804 changed files with 16368 additions and 10423 deletions

View File

@ -39,25 +39,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: uv sync --project api --dev run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check - name: Run pyrefly check
run: | run: |
cd api cd api
uv add --dev pyrefly uv add --dev pyrefly
uv run pyrefly check || true uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers - name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -1,6 +1,7 @@
#!/bin/bash #!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -14,7 +14,7 @@ The codebase is split into:
- Run backend CLI commands through `uv run --project api <command>`. - Run backend CLI commands through `uv run --project api <command>`.
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review. - Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks. - Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.

View File

@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
## Advanced Setup ## Advanced Setup
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
### Deployment with Kubernetes
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)

View File

@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key", default="plugin-api-key",
) )
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field( PLUGIN_REMOTE_INSTALL_HOST: str = Field(
@ -543,7 +548,7 @@ class UpdateConfig(BaseSettings):
class WorkflowVariableTruncationConfig(BaseSettings): class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field( WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB # 1000 KiB
1024_000, 1024_000,
description="Maximum size for variable to trigger final truncation.", description="Maximum size for variable to trigger final truncation.",
) )

View File

@ -55,3 +55,12 @@ else:
"properties", "properties",
} }
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -15,6 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization") auth_token = extract_access_token(request)
if auth_header is None: if not auth_token:
raise Unauthorized("Authorization header is missing.") raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY: if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("desc", type=str, location="json") .add_argument("app_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("copyright", type=str, location="json") .add_argument("desc", type=str, location="json")
parser.add_argument("privacy_policy", type=str, location="json") .add_argument("copyright", type=str, location="json")
parser.add_argument("custom_disclaimer", type=str, location="json") .add_argument("privacy_policy", type=str, location="json")
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") .add_argument("custom_disclaimer", type=str, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json") .add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()

View File

@ -7,13 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import api, console_ns from . import api, console_ns
from .wraps import account_initialization_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = { api_key_fields = {
"id": fields.String, "id": fields.String,
@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@ -68,15 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys} return {"items": keys}
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None _get_resource(resource_id, current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.has_edit_permission:
raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@ -93,7 +89,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24) key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -112,9 +108,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None _get_resource(resource_id, current_tenant_id, self.resource_model)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
@ -158,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -179,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app" resource_type = "app"
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = "app_id"
@ -208,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"
@ -229,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset" resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = "dataset_id"

View File

@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("app_mode", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("model_mode", type=str, required=True, location="args") .add_argument("app_mode", type=str, required=True, location="args")
parser.add_argument("has_context", type=str, required=False, default="true", location="args") .add_argument("model_mode", type=str, required=True, location="args")
parser.add_argument("model_name", type=str, required=True, location="args") .add_argument("has_context", type=str, required=False, default="true", location="args")
.add_argument("model_name", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -27,9 +27,11 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=uuid_value, required=True, location="args") reqparse.RequestParser()
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") .add_argument("message_id", type=uuid_value, required=True, location="args")
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,15 +1,14 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
@ -42,15 +41,15 @@ class AnnotationReplyActionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("score_threshold", required=True, type=float, location="json") reqparse.RequestParser()
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") .add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json") .add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
@ -69,10 +68,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200 return result, 200
@ -98,15 +95,12 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id): def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args() args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@ -124,10 +118,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
@ -159,10 +151,8 @@ class AnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str) keyword = request.args.get("keyword", default="", type=str)
@ -198,14 +188,14 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("question", required=True, type=str, location="json") reqparse.RequestParser()
parser.add_argument("answer", required=True, type=str, location="json") .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation return annotation
@ -213,10 +203,8 @@ class AnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id): def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly # Use request.args.getlist to get annotation_ids array directly
@ -249,10 +237,8 @@ class AnnotationExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id): def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)} response = {"data": marshal(annotation_list, annotation_fields)}
@ -271,16 +257,16 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("question", required=True, type=str, location="json") reqparse.RequestParser()
parser.add_argument("answer", required=True, type=str, location="json") .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation return annotation
@ -288,10 +274,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id): def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@ -310,10 +294,8 @@ class AnnotationBatchImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# check file # check file
if "file" not in request.files: if "file" not in request.files:
@ -341,10 +323,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id): def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
@ -376,10 +356,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id): def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id) app_id = str(app_id)

View File

@ -1,7 +1,5 @@
import uuid import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required, enterprise_license_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
from models import Account, App from models import App
from services.app_dsl_service import AppDslService, ImportMode from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -56,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
"""Get app list""" """Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value): def uuid_list(value):
try: try:
@ -63,34 +63,36 @@ class AppListApi(Resource):
except ValueError: except ValueError:
abort(400, message="Invalid UUID format in tag_ids.") abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument( .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
"mode", .add_argument(
type=str, "mode",
choices=[ type=str,
"completion", choices=[
"chat", "completion",
"advanced-chat", "chat",
"workflow", "advanced-chat",
"agent-chat", "workflow",
"channel", "agent-chat",
"all", "channel",
], "all",
default="all", ],
location="args", default="all",
required=False, location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
) )
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
args = parser.parse_args() args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -129,30 +131,26 @@ class AppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, location="json") parser = (
parser.add_argument("description", type=validate_description_length, location="json") reqparse.RequestParser()
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") .add_argument("name", type=str, required=True, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None: if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
if not isinstance(current_user, Account): app = app_service.create_app(current_tenant_id, args, current_user)
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
return app, 201 return app, 201
@ -205,21 +203,20 @@ class AppApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def put(self, app_model): def put(self, app_model):
"""Update app""" """Update app"""
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
parser = reqparse.RequestParser() .add_argument("icon_type", type=str, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("description", type=validate_description_length, location="json") .add_argument("icon_background", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("max_active_requests", type=int, location="json")
parser.add_argument("icon_background", type=str, location="json") )
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -248,12 +245,9 @@ class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model): def delete(self, app_model):
"""Delete app""" """Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService() app_service = AppService()
app_service.delete_app(app_model) app_service.delete_app(app_model)
@ -283,27 +277,28 @@ class AppCopyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
def post(self, app_model): def post(self, app_model):
"""Copy app""" """Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, location="json") reqparse.RequestParser()
parser.add_argument("description", type=validate_description_length, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=validate_description_length, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app( result = import_service.import_app(
account=account, account=current_user,
import_mode=ImportMode.YAML_CONTENT, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.get("name"), name=args.get("name"),
@ -340,16 +335,15 @@ class AppExportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = (
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") reqparse.RequestParser()
parser.add_argument("workflow_id", type=str, location="args") .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args() args = parser.parse_args()
return { return {
@ -371,13 +365,9 @@ class AppNameApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -408,14 +398,13 @@ class AppIconApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = (
if not current_user.is_editor: reqparse.RequestParser()
raise Forbidden() .add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
parser = reqparse.RequestParser() )
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -441,13 +430,9 @@ class AppSiteStatus(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -475,11 +460,11 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
parser.add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -520,13 +505,14 @@ class AppTraceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, app_id): def post(self, app_id):
# add app trace # add app trace
if not current_user.is_editor: parser = (
raise Forbidden() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("enabled", type=bool, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(

View File

@ -1,20 +1,16 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import App from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
@ -30,28 +26,29 @@ class AppImportApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("mode", type=str, required=True, location="json")
parser.add_argument("mode", type=str, required=True, location="json") .add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_content", type=str, location="json") .add_argument("yaml_url", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("name", type=str, location="json") .add_argument("description", type=str, location="json")
parser.add_argument("description", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_background", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("app_id", type=str, location="json")
parser.add_argument("app_id", type=str, location="json") )
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -83,16 +80,16 @@ class AppImportConfirmApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id): def post(self, import_id):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
@ -109,10 +106,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model @get_app_model
@account_initialization_required @account_initialization_required
@marshal_with(app_import_check_dependencies_fields) @marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model) result = import_service.check_dependencies(app_model=app_model)

View File

@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, location="json") reqparse.RequestParser()
parser.add_argument("text", type=str, location="json") .add_argument("message_id", type=str, location="json")
parser.add_argument("voice", type=str, location="json") .add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)
@ -166,8 +168,7 @@ class TextModesApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(

View File

@ -2,7 +2,7 @@ import logging
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError, ProviderQuotaExceededError,
) )
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("model_config", type=dict, required=True, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=True, location="json")
raise Forbidden() .add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
parser = reqparse.RequestParser() .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("inputs", type=dict, required=True, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("query", type=str, required=True, location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("model_config", type=dict, required=True, location="json") )
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args["response_mode"] != "blocking"

View File

@ -1,17 +1,16 @@
from datetime import datetime from datetime import datetime
import pytz # pip install pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import ( from fields.conversation_fields import (
@ -22,8 +21,8 @@ from fields.conversation_fields import (
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -57,18 +56,24 @@ class CompletionConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields) @marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = sa.select(Conversation).where( query = sa.select(Conversation).where(
@ -84,6 +89,7 @@ class CompletionConversationApi(Resource):
) )
account = current_user account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -137,9 +143,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields) @marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -154,14 +159,12 @@ class CompletionConversationDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -206,26 +209,32 @@ class ChatConversationApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields) @marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") .add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( .add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status",
) type=str,
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") choices=["annotated", "not_annotated", "all"],
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") default="all",
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") location="args",
parser.add_argument( )
"sort_by", .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
type=str, .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
choices=["created_at", "-created_at", "updated_at", "-updated_at"], .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
required=False, .add_argument(
default="-updated_at", "sort_by",
location="args", type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -260,6 +269,7 @@ class ChatConversationApi(Resource):
) )
account = current_user account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -341,9 +351,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields) @marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@ -358,14 +367,12 @@ class ChatConversationDetailApi(Resource):
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id): def delete(self, app_model, conversation_id):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user) ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -374,6 +381,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT) @get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_fields) @marshal_with(paginated_conversation_variable_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
parser.add_argument("conversation_id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
stmt = ( stmt = (

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.llm_generator import LLMGenerator from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -43,16 +42,18 @@ class RuleGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=args["no_variable"], no_variable=args["no_variable"],
@ -93,17 +94,19 @@ class RuleCodeGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json") .add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["code_language"], code_language=args["code_language"],
@ -140,15 +143,17 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
) )
@ -189,15 +194,18 @@ class InstructionGenerateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("flow_id", type=str, required=True, default="", location="json") reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=False, default="", location="json") .add_argument("flow_id", type=str, required=True, default="", location="json")
parser.add_argument("current", type=str, required=False, default="", location="json") .add_argument("node_id", type=str, required=False, default="", location="json")
parser.add_argument("language", type=str, required=False, default="javascript", location="json") .add_argument("current", type=str, required=False, default="", location="json")
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") .add_argument("language", type=str, required=False, default="javascript", location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") .add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json") .add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = ( code_template = (
Python3CodeProvider.get_default_code() Python3CodeProvider.get_default_code()
if args["language"] == "python" if args["language"] == "python"
@ -222,21 +230,21 @@ class InstructionGenerateApi(Resource):
match node_type: match node_type:
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_user.current_tenant_id, current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["language"], code_language=args["language"],
@ -245,7 +253,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy( return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
current=args["current"], current=args["current"],
instruction=args["instruction"], instruction=args["instruction"],
@ -254,7 +262,7 @@ class InstructionGenerateApi(Resource):
) )
if args["node_id"] != "" and args["current"] != "": # For workflow node if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow( return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
flow_id=args["flow_id"], flow_id=args["flow_id"],
node_id=args["node_id"], node_id=args["node_id"],
current=args["current"], current=args["current"],
@ -293,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()
match args["type"]: match args["type"]:
case "prompt": case "prompt":

View File

@ -1,16 +1,15 @@
import json import json
from enum import StrEnum from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_server_fields from fields.app_fields import app_server_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application") @api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields) @api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model @get_app_model
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
) )
@api.response(201, "MCP server configuration created successfully", app_server_fields) @api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model): def post(self, app_model):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound() parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
description = args.get("description") description = args.get("description")
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False), parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE, status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id, app_id=app_model.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16), server_code=AppMCPServer.generate_server_code(16),
) )
db.session.add(server) db.session.add(server)
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields) @api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@api.response(404, "Server not found") @api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model @get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model): def put(self, app_model):
if not current_user.is_editor: parser = (
raise NotFound() reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("id", type=str, required=True, location="json")
parser.add_argument("id", type=str, required=True, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") .add_argument("status", type=str, required=False, location="json")
parser.add_argument("status", type=str, required=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id): def get(self, server_id):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id) .where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id) .where(AppMCPServer.tenant_id == current_tenant_id)
.first() .first()
) )
if not server: if not server:

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required,
setup_required, setup_required,
) )
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
@ -56,19 +56,19 @@ class ChatMessageListApi(Resource):
) )
@api.response(200, "Success", message_infinite_scroll_pagination_fields) @api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required @account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model): def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission: parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser = reqparse.RequestParser() .add_argument("first_id", type=uuid_value, location="args")
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("first_id", type=uuid_value, location="args") )
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
conversation = ( conversation = (
@ -154,12 +154,13 @@ class MessageFeedbackApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, app_model): def post(self, app_model):
if current_user is None: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", required=True, type=uuid_value, location="json") reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") .add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = str(args["message_id"]) message_id = str(args["message_id"])
@ -211,23 +212,21 @@ class MessageAnnotationApi(Resource):
) )
@api.response(200, "Annotation created successfully", annotation_fields) @api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
@get_app_model @account_initialization_required
@marshal_with(annotation_fields) @edit_permission_required
def post(self, app_model): def post(self, app_model):
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("message_id", required=False, type=uuid_value, location="json")
raise Forbidden() .add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
parser = reqparse.RequestParser() .add_argument("annotation_reply", required=False, type=dict, location="json")
parser.add_argument("message_id", required=False, type=uuid_value, location="json") )
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@ -270,6 +269,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id): def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id) message_id = str(message_id)
try: try:
@ -304,12 +304,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields) @api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found") @api.response(404, "Message not found")
@get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model
@marshal_with(message_detail_fields) @marshal_with(message_detail_fields)
def get(self, app_model, message_id): def get(self, app_model, message_id: str):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@ -2,7 +2,6 @@ import json
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated from events.app_event import app_model_config_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import AppMode, AppModelConfig from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService from services.app_model_config_service import AppModelConfigService
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise Forbidden()
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
config=cast(dict, request.json), config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode), app_mode=AppMode.value_of(app_model.mode),
) )
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool # get tool
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
else: else:
try: try:
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
app_id=app_model.id, app_id=app_model.id,
agent_tool=agent_tool_entity, agent_tool=agent_tool_entity,
) )
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
continue continue
manager = ToolParameterConfigurationManager( manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,

View File

@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = reqparse.RequestParser() parser = (
parser.add_argument("tracing_provider", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("tracing_config", type=dict, required=True, location="json") .add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
parser.add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account, Site from models import Site
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = (
parser.add_argument("title", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("icon_type", type=str, required=False, location="json") .add_argument("title", type=str, required=False, location="json")
parser.add_argument("icon", type=str, required=False, location="json") .add_argument("icon_type", type=str, required=False, location="json")
parser.add_argument("icon_background", type=str, required=False, location="json") .add_argument("icon", type=str, required=False, location="json")
parser.add_argument("description", type=str, required=False, location="json") .add_argument("icon_background", type=str, required=False, location="json")
parser.add_argument("default_language", type=supported_language, required=False, location="json") .add_argument("description", type=str, required=False, location="json")
parser.add_argument("chat_color_theme", type=str, required=False, location="json") .add_argument("default_language", type=supported_language, required=False, location="json")
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") .add_argument("chat_color_theme", type=str, required=False, location="json")
parser.add_argument("customize_domain", type=str, required=False, location="json") .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
parser.add_argument("copyright", type=str, required=False, location="json") .add_argument("customize_domain", type=str, required=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, location="json") .add_argument("copyright", type=str, required=False, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, location="json") .add_argument("privacy_policy", type=str, required=False, location="json")
parser.add_argument( .add_argument("custom_disclaimer", type=str, required=False, location="json")
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" .add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
) )
parser.add_argument("prompt_public", type=bool, required=False, location="json")
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args() return parser.parse_args()
@ -76,9 +81,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner # The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@ -107,8 +113,6 @@ class AppSite(Resource):
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()
@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = naive_utc_now()
db.session.commit() db.session.commit()

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask import jsonify from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message from models import AppMode, Message
@ -37,11 +36,13 @@ class DailyMessageStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -53,6 +54,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -109,13 +111,15 @@ class DailyConversationStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -175,11 +179,13 @@ class DailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -191,7 +197,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -247,11 +253,13 @@ class DailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -264,7 +272,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -322,11 +330,13 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -346,7 +356,7 @@ FROM
c.app_id = :app_id c.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -413,11 +423,13 @@ class UserSatisfactionRateStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -433,7 +445,7 @@ WHERE
m.app_id = :app_id m.app_id = :app_id
AND m.invoke_from != :invoke_from""" AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -494,11 +506,13 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -510,7 +524,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -566,11 +580,13 @@ class TokensPerSecondStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -585,7 +601,7 @@ WHERE
app_id = :app_id app_id = :app_id
AND invoke_from != :invoke_from""" AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER} arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc

View File

@ -12,7 +12,7 @@ import services
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -27,9 +27,8 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from models.account import Account
from models.model import AppMode from models.model import AppMode
from models.workflow import Workflow from models.workflow import Workflow
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get draft workflow Get draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model) workflow = workflow_service.get_draft_workflow(app_model=app_model)
@ -110,24 +105,24 @@ class DraftWorkflowApi(Resource):
@api.response(200, "Draft workflow synced successfully", workflow_fields) @api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration") @api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied") @api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, nullable=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("features", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -149,10 +144,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
else: else:
abort(415) abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -206,24 +197,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
parser = reqparse.RequestParser() .add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("inputs", type=dict, location="json") .add_argument("files", type=list, location="json")
parser.add_argument("query", type=str, required=True, location="json", default="") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("files", type=list, location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") )
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -271,18 +259,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -323,18 +306,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -375,19 +353,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -428,19 +400,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -480,20 +446,17 @@ class DraftWorkflowRunApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Run draft workflow Run draft workflow
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("files", type=list, required=False, location="json")
raise Forbidden() )
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
@ -526,17 +489,11 @@ class WorkflowTaskStopApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str): def post(self, app_model: App, task_id: str):
""" """
Stop workflow task Stop workflow task
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility # Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check) # Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id) AppQueueManager.set_stop_flag_no_user_check(task_id)
@ -568,21 +525,18 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow node Run draft workflow node
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
# The role of the current user in the ta table must be admin, owner, or editor .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
if not current_user.has_edit_permission: .add_argument("query", type=str, required=False, location="json", default="")
raise Forbidden() .add_argument("files", type=list, location="json", default=[])
)
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args() args = parser.parse_args()
user_inputs = args.get("inputs") user_inputs = args.get("inputs")
@ -622,17 +576,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflow Get published workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model # fetch published workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model) workflow = workflow_service.get_published_workflow(app_model=app_model)
@ -644,19 +592,17 @@ class PublishedWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Publish workflow Publish workflow
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
# The role of the current user in the ta table must be admin, owner, or editor reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("marked_name", type=str, required=False, default="", location="json")
raise Forbidden() .add_argument("marked_comment", type=str, required=False, default="", location="json")
)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -702,17 +648,11 @@ class DefaultBlockConfigsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs # Get default block configs
workflow_service = WorkflowService() workflow_service = WorkflowService()
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
@ -729,18 +669,12 @@ class DefaultBlockConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str): def get(self, app_model: App, block_type: str):
""" """
Get default block config Get default block config
""" """
if not isinstance(current_user, Account): parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -769,24 +703,23 @@ class ConvertToWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
Convert basic mode of chatbot app to workflow mode Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App Convert Completion App to Workflow App
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") .add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json") .add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") .add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
else: else:
args = {} args = {}
@ -812,21 +745,20 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields) @marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get published workflows Get published workflows
""" """
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account): parser = (
raise Forbidden() reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
raise Forbidden() .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument("user_id", type=str, required=False, location="args")
parser = reqparse.RequestParser() .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("user_id", type=str, required=False, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
args = parser.parse_args() args = parser.parse_args()
page = int(args.get("page", 1)) page = int(args.get("page", 1))
limit = int(args.get("limit", 10)) limit = int(args.get("limit", 10))
@ -879,19 +811,17 @@ class WorkflowByIdApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str): def patch(self, app_model: App, workflow_id: str):
""" """
Update workflow attributes Update workflow attributes
""" """
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
# Check permission reqparse.RequestParser()
if not current_user.has_edit_permission: .add_argument("marked_name", type=str, required=False, location="json")
raise Forbidden() .add_argument("marked_comment", type=str, required=False, location="json")
)
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -934,16 +864,11 @@ class WorkflowByIdApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str): def delete(self, app_model: App, workflow_id: str):
""" """
Delete workflow Delete workflow
""" """
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
# Create a session and manage the transaction # Create a session and manage the transaction

View File

@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource):
""" """
Get workflow app logs Get workflow app logs
""" """
parser = reqparse.RequestParser() parser = (
parser.add_argument("keyword", type=str, location="args") reqparse.RequestParser()
parser.add_argument( .add_argument("keyword", type=str, location="args")
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args" .add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None args.status = WorkflowExecutionStatus(args.status) if args.status else None

View File

@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App, AppMode from models import Account, App, AppMode
from models.account import Account
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -320,10 +321,11 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,5 @@
from typing import cast from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
@ -9,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import ( from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields, advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields,
workflow_run_detail_fields, workflow_run_detail_fields,
workflow_run_node_execution_list_fields, workflow_run_node_execution_list_fields,
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
from libs.custom_inputs import time_duration
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from services.workflow_run_service import WorkflowRunService from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = reqparse.RequestParser()
parser.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
parser.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
parser.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource): class AdvancedChatAppWorkflowRunListApi(Resource):
@ -25,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@api.doc(description="Get advanced chat workflow run list") @api.doc(description="Get advanced chat workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -35,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING if not specified
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@api.doc("get_advanced_chat_workflow_runs_count")
@api.doc(description="Get advanced chat workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get advanced chat workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result
@ -52,6 +170,8 @@ class WorkflowRunListApi(Resource):
@api.doc(description="Get workflow run list") @api.doc(description="Get workflow run list")
@api.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@ -62,13 +182,64 @@ class WorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() args = _parse_workflow_run_list_args()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") # Default to DEBUGGING for workflow if not specified (backward compatibility)
args = parser.parse_args() triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model, args=args, triggered_from=triggered_from
)
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource):
@api.doc("get_workflow_runs_count")
@api.doc(description="Get workflow runs count statistics")
@api.doc(params={"app_id": "Application ID"})
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
@api.doc(
params={
"time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
)
}
)
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_fields)
def get(self, app_model: App):
"""
Get workflow runs count statistics
"""
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_workflow_runs_count(
app_model=app_model,
status=args.get("status"),
time_range=args.get("time_range"),
triggered_from=triggered_from,
)
return result return result

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz import pytz
import sqlalchemy as sa import sqlalchemy as sa
from flask import jsonify from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api, console_ns from controllers.console import api, console_ns
@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
@ -29,11 +28,13 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -49,7 +50,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -97,11 +98,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -117,7 +120,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -165,11 +168,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -185,7 +190,7 @@ WHERE
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
@ -238,11 +243,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model): def get(self, app_model):
account = current_user account, _ = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") reqparse.RequestParser()
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT
@ -271,7 +278,7 @@ GROUP BY
"app_id": app_model.id, "app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
} }
assert account.timezone is not None
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc

View File

@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models import App, AppMode from models import App, AppMode
from models.account import Account
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None: def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]): def decorator(view_func: Callable[P1, R1]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")

View File

@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser() active_check_parser = (
active_check_parser.add_argument( reqparse.RequestParser()
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
) .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
active_check_parser.add_argument( .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
) )
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = reqparse.RequestParser() active_parser = (
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") .add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument( .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
"interface_language", type=supported_language, required=True, nullable=False, location="json" .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
) )
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate") @console_ns.route("/activate")

View File

@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) _, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings: if data_source_api_key_bindings:
return { return {
"sources": [ "sources": [
@ -41,16 +41,20 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("category", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args) ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -63,9 +67,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required @account_initialization_required
def delete(self, binding_id): def delete(self, binding_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -2,13 +2,12 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api, console_ns from controllers.console import api, console_ns
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import Account from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") .add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate passwords match # Validate passwords match

View File

@ -1,7 +1,5 @@
from typing import cast
import flask_login import flask_login
from flask import request from flask import make_response, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
import services import services
@ -26,7 +24,17 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from models.account import Account from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
@ -42,11 +50,13 @@ class LoginApi(Resource):
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("password", type=str, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") .add_argument("password", type=str, required=True, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json") .add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
@ -89,19 +99,36 @@ class LoginApi(Resource):
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout") @console_ns.route("/logout")
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def post(self):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin): if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"} response = make_response({"result": "success"})
AccountService.logout(account=account) else:
flask_login.logout_user() AccountService.logout(account=account)
return {"result": "success"} flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password") @console_ns.route("/reset-password")
@ -109,9 +136,11 @@ class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans": if args["language"] is not None and args["language"] == "zh-Hans":
@ -137,9 +166,11 @@ class ResetPasswordSendEmailApi(Resource):
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("language", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -170,10 +201,12 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -220,18 +253,46 @@ class EmailCodeLoginApi(Resource):
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token") @console_ns.route("/refresh-token")
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() # Get refresh token from cookie instead of request body
parser.add_argument("refresh_token", type=str, required=True, location="json") refresh_token = request.cookies.get("refresh_token")
args = parser.parse_args()
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try: try:
new_token_pair = AccountService.refresh_token(args["refresh_token"]) new_token_pair = AccountService.refresh_token(refresh_token)
return {"result": "success", "data": new_token_pair.model_dump()}
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e: except Exception as e:
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@ -14,8 +14,12 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from libs.token import (
from models.account import AccountStatus set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
@ -153,9 +157,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
) )
return redirect( response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
) set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

View File

@ -1,16 +1,15 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request from flask import jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account from models import Account
from models.model import OAuthProviderApp from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -24,8 +23,7 @@ T = TypeVar("T")
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id") client_id = parsed_args.get("client_id")
if not client_id: if not client_id:
@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri") redirect_uri = parsed_args.get("redirect_uri")
@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required @account_initialization_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user) current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id) code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser() parser = (
parser.add_argument("grant_type", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=False, location="json") .add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json") .add_argument("code", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json") .add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json") .add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
try: try:

View File

@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@ -14,17 +13,15 @@ class Subscription(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser = (
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) reqparse.RequestParser()
args = parser.parse_args() .add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
assert isinstance(current_user, Account) .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
) )
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices") @console_ns.route("/billing/invoices")
@ -34,7 +31,6 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None return BillingService.get_invoices(current_user.email, current_tenant_id)
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -2,8 +2,7 @@ from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from .. import console_ns from .. import console_ns
@ -17,19 +16,16 @@ class ComplianceApi(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device") device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link( return BillingService.get_compliance_download_link(
doc_name=args.doc_name, doc_name=args.doc_name,
account_id=current_user.id, account_id=current_user.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
ip=ip_address, ip=ip_address,
device_info=device_info, device_info=device_info,
) )

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = db.session.scalars( data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where( select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
).all() ).all()
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_notion_info_list_fields) @marshal_with(integrate_notion_info_list_fields)
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars( documents = session.scalars(
select(Document).filter_by( select(Document).filter_by(
dataset_id=dataset_id, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
data_source_type="notion_import", data_source_type="notion_import",
enabled=True, enabled=True,
) )
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime( datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource", provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource", datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
) )
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, workspace_id, page_id, page_type): def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str) credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id: if not credential_id:
raise ValueError("Credential id is required.") raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials( credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
credential_id=credential_id, credential_id=credential_id,
provider="notion_datasource", provider="notion_datasource",
plugin_id="langgenius/notion_datasource", plugin_id="langgenius/notion_datasource",
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"), notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
) )
text_docs = extractor.extract() text_docs = extractor.extract()
@ -239,12 +244,14 @@ class DataSourceNotionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() _, current_tenant_id = current_account_with_tenant()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") parser = (
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
"notion_page_type": page["type"], "notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=args["doc_form"], document_model=args["doc_form"],
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],

View File

@ -1,7 +1,6 @@
from typing import Any, cast from typing import Any, cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -138,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids") ids = request.args.getlist("ids")
@ -146,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true" include_all = request.args.get("include_all", default="false").lower() == "true"
if ids: if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else: else:
datasets, total = DatasetService.get_datasets( datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all page, limit, current_tenant_id, current_user, search, tag_ids, include_all
) )
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -207,50 +206,53 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=validate_description_length, "description",
nullable=True, type=validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"external_knowledge_api_id", .add_argument(
type=str, "external_knowledge_api_id",
nullable=True, type=str,
required=False, nullable=True,
) required=False,
parser.add_argument( )
"provider", .add_argument(
type=str, "provider",
nullable=True, type=str,
choices=Dataset.PROVIDER_LIST, nullable=True,
required=False, choices=Dataset.PROVIDER_LIST,
default="vendor", required=False,
) default="vendor",
parser.add_argument( )
"external_knowledge_id", .add_argument(
type=str, "external_knowledge_id",
nullable=True, type=str,
required=False, nullable=True,
required=False,
)
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -258,11 +260,11 @@ class DatasetListApi(Resource):
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=args["name"],
description=args["description"], description=args["description"],
indexing_technique=args["indexing_technique"], indexing_technique=args["indexing_technique"],
account=cast(Account, current_user), account=current_user,
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
@ -286,6 +288,7 @@ class DatasetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -305,7 +308,7 @@ class DatasetApi(Resource):
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -351,73 +354,76 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
help="type is required. Name must be between 1 to 40 characters.", nullable=False,
type=_validate_name, help="type is required. Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length) )
parser.add_argument( .add_argument("description", location="json", store_missing=False, type=validate_description_length)
"indexing_technique", .add_argument(
type=str, "indexing_technique",
location="json", type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, location="json",
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.", nullable=True,
) help="Invalid indexing technique.",
parser.add_argument( )
"permission", .add_argument(
type=str, "permission",
location="json", type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), location="json",
help="Invalid permission.", choices=(
) DatasetPermissionEnum.ONLY_ME,
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") DatasetPermissionEnum.ALL_TEAM,
parser.add_argument( DatasetPermissionEnum.PARTIAL_TEAM,
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." ),
) help="Invalid permission.",
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") )
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
parser.add_argument( "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
"external_retrieval_model", )
type=dict, .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
required=False, .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
nullable=True, .add_argument(
location="json", "external_retrieval_model",
help="Invalid external retrieval model.", type=dict,
) required=False,
nullable=True,
parser.add_argument( location="json",
"external_knowledge_id", help="Invalid external retrieval model.",
type=str, )
required=False, .add_argument(
nullable=True, "external_knowledge_id",
location="json", type=str,
help="Invalid external knowledge id.", required=False,
) nullable=True,
location="json",
parser.add_argument( help="Invalid external knowledge id.",
"external_knowledge_api_id", )
type=str, .add_argument(
required=False, "external_knowledge_api_id",
nullable=True, type=str,
location="json", required=False,
help="Invalid external knowledge api id.", nullable=True,
) location="json",
help="Invalid external knowledge api id.",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
required=False, type=dict,
nullable=True, required=False,
location="json", nullable=True,
help="Invalid icon info.", location="json",
help="Invalid icon info.",
)
) )
args = parser.parse_args() args = parser.parse_args()
data = request.get_json() data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting # check embedding model setting
if ( if (
@ -440,7 +446,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
@ -464,9 +470,9 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id): def delete(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
try: try:
@ -505,6 +511,7 @@ class DatasetQueryApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -539,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
parser.add_argument( .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
"indexing_technique", .add_argument(
type=str, "indexing_technique",
required=True, type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
nullable=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
location="json", nullable=True,
) location="json",
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") )
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument( .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
extract_settings = [] extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file": if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars( file_details = db.session.scalars(
select(UploadFile).where( select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
).all() ).all()
if file_details is None: if file_details is None:
@ -592,7 +598,7 @@ class DatasetIndexingEstimateApi(Resource):
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page["page_id"],
"notion_page_type": page["type"], "notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=args["doc_form"], document_model=args["doc_form"],
@ -608,7 +614,7 @@ class DatasetIndexingEstimateApi(Resource):
"provider": website_info_list["provider"], "provider": website_info_list["provider"],
"job_id": website_info_list["job_id"], "job_id": website_info_list["job_id"],
"url": url, "url": url,
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
"mode": "crawl", "mode": "crawl",
"only_main_content": website_info_list["only_main_content"], "only_main_content": website_info_list["only_main_content"],
} }
@ -621,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
args["process_rule"], args["process_rule"],
args["doc_form"], args["doc_form"],
@ -652,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(related_app_list) @marshal_with(related_app_list)
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -683,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = db.session.scalars( documents = db.session.scalars(
select(Document).where( select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
).all() ).all()
documents_status = [] documents_status = []
for document in documents: for document in documents:
@ -739,10 +745,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars( keys = db.session.scalars(
select(ApiToken).where( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
).all() ).all()
return {"items": keys} return {"items": keys}
@ -752,12 +757,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self): def post(self):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count() .count()
) )
@ -770,7 +776,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken() api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_tenant_id
api_token.token = key api_token.token = key
api_token.type = self.resource_type api_token.type = self.resource_type
db.session.add(api_token) db.session.add(api_token)
@ -790,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, api_key_id): def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
@ -799,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .where(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
) )
@ -898,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

View File

@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -53,9 +52,8 @@ from fields.document_fields import (
document_with_segments_fields, document_with_segments_fields,
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -79,12 +78,13 @@ class DocumentResource(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
return document return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args req_data = request.args
document_id = req_data.get("document_id") document_id = req_data.get("document_id")
@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -289,20 +292,20 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" .add_argument(
) "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
parser.add_argument("data_source", type=dict, required=False, location="json") )
parser.add_argument("process_rule", type=dict, required=False, location="json") .add_argument("data_source", type=dict, required=False, location="json")
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") .add_argument("process_rule", type=dict, required=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") .add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
"doc_language", type=str, default="English", required=False, nullable=False, location="json" .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
) )
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args) knowledge_config = KnowledgeConfig.model_validate(args)
@ -372,27 +375,28 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"indexing_technique", .add_argument(
type=str, "indexing_technique",
choices=Dataset.INDEXING_TECHNIQUE_LIST, type=str,
required=True, choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=False, required=True,
location="json", nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
) )
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args) knowledge_config = KnowledgeConfig.model_validate(args)
@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=args["embedding_model_provider"], provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"], model=args["embedding_model"],
@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
knowledge_config=knowledge_config, knowledge_config=knowledge_config,
account=cast(Account, current_user), account=current_user,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
try: try:
estimate_response = indexing_runner.indexing_estimate( estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
[extract_setting], [extract_setting],
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, batch): def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
batch = str(batch) batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"], "notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
} }
), ),
document_model=document.doc_form, document_model=document.doc_form,
@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"provider": data_source_info["provider"], "provider": data_source_info["provider"],
"job_id": data_source_info["job_id"], "job_id": data_source_info["job_id"],
"url": data_source_info["url"], "url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_tenant_id,
"mode": data_source_info["mode"], "mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"], "only_main_content": data_source_info["only_main_content"],
} }
@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_tenant_id,
extract_settings, extract_settings,
data_process_rule_dict, data_process_rule_dict,
document.doc_form, document.doc_form,
@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, dataset_id, document_id): def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
@ -1034,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") "document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -1077,14 +1087,14 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields) @marshal_with(document_fields)
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset) DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
"""sync website document.""" """sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id: if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.") raise Forbidden("No permission.")
if document.data_source_type != "website_crawl": if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.") raise ValueError("Document is not a website document.")

View File

@ -1,7 +1,6 @@
import uuid import uuid
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id): def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -59,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("status", type=str, action="append", default=[], location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args") .add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("enabled", type=str, default="all", location="args") .add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -79,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment) select(DocumentSegment)
.where( .where(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_tenant_id,
) )
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
) )
@ -115,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id): def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -148,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
@ -171,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -204,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -221,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -237,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset) segment = SegmentService.create_segment(args, document, dataset)
@ -255,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -272,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -287,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -300,12 +313,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = (
parser.add_argument("content", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") .add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") .add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument( .add_argument("keywords", type=list, required=False, nullable=True, location="json")
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" .add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
) )
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
@ -317,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id): def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -333,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -361,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -372,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json") "upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
upload_file_id = args["upload_file_id"] upload_file_id = args["upload_file_id"]
@ -396,7 +416,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id, upload_file_id,
dataset_id, dataset_id,
document_id, document_id,
current_user.current_tenant_id, current_tenant_id,
current_user.id, current_user.id,
) )
except Exception as e: except Exception as e:
@ -427,6 +447,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -441,7 +463,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -453,7 +475,7 @@ class ChildChunkAddApi(Resource):
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
@ -469,8 +491,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
content = args["content"] content = args["content"]
@ -483,6 +506,8 @@ class ChildChunkAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id, document_id, segment_id): def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -499,15 +524,17 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = reqparse.RequestParser() parser = (
parser.add_argument("limit", type=int, default=20, location="args") reqparse.RequestParser()
parser.add_argument("keyword", type=str, default=None, location="args") .add_argument("limit", type=int, default=20, location="args")
parser.add_argument("page", type=int, default=1, location="args") .add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -530,6 +557,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -546,7 +575,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -559,8 +588,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") "chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
chunks_data = args["chunks"] chunks_data = args["chunks"]
@ -580,6 +610,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -596,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -607,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )
@ -634,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -650,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -661,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
ChildChunk.id == str(child_chunk_id), ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id, ChildChunk.document_id == document_id,
) )
@ -677,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("content", type=str, required=True, nullable=False, location="json") "content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
content = args["content"] content = args["content"]

View File

@ -1,7 +1,4 @@
from typing import cast
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -10,8 +7,7 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService from services.hit_testing_service import HitTestingService
@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search page, limit, current_tenant_id, search
) )
response = { response = {
"data": [item.to_dict() for item in external_knowledge_apis], "data": [item.to_dict() for item in external_knowledge_apis],
@ -60,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument( parser = (
"name", reqparse.RequestParser()
nullable=False, .add_argument(
required=True, "name",
help="Name is required. Name must be between 1 to 100 characters.", nullable=False,
type=_validate_name, required=True,
) help="Name is required. Name must be between 1 to 100 characters.",
parser.add_argument( type=_validate_name,
"settings", )
type=dict, .add_argument(
location="json", "settings",
nullable=False, type=dict,
required=True, location="json",
nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -85,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
try: try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args tenant_id=current_tenant_id, user_id=current_user.id, args=args
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -115,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, external_knowledge_api_id): def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="type is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="type is required. Name must be between 1 to 100 characters.",
) type=_validate_name,
parser.add_argument( )
"settings", .add_argument(
type=dict, "settings",
location="json", type=dict,
nullable=False, location="json",
required=True, nullable=False,
required=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"]) ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
args=args, args=args,
@ -148,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, external_knowledge_api_id): def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not (current_user.is_editor or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -199,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument( .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="name is required. Name must be between 1 to 100 characters.", required=True,
type=_validate_name, help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
) )
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -223,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
try: try:
dataset = ExternalDatasetService.create_external_dataset( dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
args=args, args=args,
) )
@ -255,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -265,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = (
parser.add_argument("query", type=str, location="json") reqparse.RequestParser()
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@ -277,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args["query"],
account=cast(Account, current_user), account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=args["metadata_filtering_conditions"],
) )
@ -304,15 +313,17 @@ class BedrockRetrievalApi(Resource):
) )
@api.response(200, "Bedrock retrieval test completed") @api.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
"query", .add_argument(
nullable=False, "query",
required=True, nullable=False,
type=str, required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
) )
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args() args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service

View File

@ -48,11 +48,12 @@ class DatasetsHitTestingBase:
@staticmethod @staticmethod
def parse_args(): def parse_args():
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("query", type=str, location="json") .add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
return parser.parse_args() return parser.parse_args()
@staticmethod @staticmethod

View File

@ -1,13 +1,12 @@
from typing import Literal from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs, MetadataArgs,
@ -24,9 +23,12 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def post(self, dataset_id): def post(self, dataset_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser = (
parser.add_argument("name", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args) metadata_args = MetadataArgs.model_validate(args)
@ -59,8 +61,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
name = args["name"] name = args["name"]
@ -79,6 +81,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def delete(self, dataset_id, metadata_id): def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -108,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]): def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
@ -128,14 +132,16 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") "operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args) metadata_args = MetadataOperationData.model_validate(args)

View File

@ -1,19 +1,15 @@
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen from libs.helper import StrLen
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, provider_id: str): def get(self, provider_id: str):
user = current_user current_user, current_tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
if not current_user.is_editor: tenant_id = current_tenant_id
raise Forbidden()
credential_id = request.args.get("credential_id") credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=current_user.id,
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_name, provider=provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
@ -130,22 +126,24 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None .add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
try: try:
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=args["credentials"],
name=args["name"], name=args["name"],
@ -160,8 +158,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str): def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials( datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
) )
@ -173,18 +173,21 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden() parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "credential_id", type=str, required=True, nullable=False, location="json"
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
@ -197,18 +200,22 @@ class DatasourceAuthUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
@ -224,10 +231,10 @@ class DatasourceAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials( datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@ -237,10 +244,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials( datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
@ -249,17 +256,20 @@ class DatasourceAuthOauthCustomClient(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False), enabled=args.get("enable_oauth_custom_client", False),
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider_id: str): def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params( datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -284,16 +296,16 @@ class DatasourceAuthDefaultApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=args["id"],
) )
@ -305,17 +317,20 @@ class DatasourceUpdateProviderNameApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self, provider_id: str): def post(self, provider_id: str):
if not current_user.is_editor: _, current_tenant_id = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=args["name"],
credential_id=args["credential_id"], credential_id=args["credential_id"],

View File

@ -26,10 +26,12 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")

View File

@ -66,26 +66,28 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=_validate_description_length, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", .add_argument(
nullable=False, "name",
required=True, nullable=False,
help="Name must be between 1 to 40 characters.", required=True,
type=_validate_name, help="Name must be between 1 to 40 characters.",
) type=_validate_name,
parser.add_argument( )
"description", .add_argument(
type=_validate_description_length, "description",
nullable=True, type=_validate_description_length,
required=False, nullable=True,
default="", required=False,
) default="",
parser.add_argument( )
"icon_info", .add_argument(
type=dict, "icon_info",
location="json", type=dict,
nullable=True, location="json",
nullable=True,
)
) )
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,7 +12,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -27,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"yaml_content", "yaml_content",
type=str, type=str,
nullable=False, nullable=False,
@ -38,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
@ -58,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session) rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
) )
if rag_pipeline_dataset_create_entity.permission == "partial_members": if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id, current_tenant_id,
import_info["dataset_id"], import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list, rag_pipeline_dataset_create_entity.partial_member_list,
) )
@ -81,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset( dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity( rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="", name="",
description="", description="",

View File

@ -23,7 +23,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser(): def _create_pagination_parser():
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"page", .add_argument(
type=inputs.int_range(1, 100_000), "page",
required=False, type=inputs.int_range(1, 100_000),
default=1, required=False,
location="args", default=1,
help="the page of data requested", location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
) )
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser return parser
@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = reqparse.RequestParser() parser = (
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") reqparse.RequestParser()
# Parse 'value' field as-is to maintain its original data structure .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -28,26 +24,29 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("mode", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("yaml_content", type=str, location="json") .add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_url", type=str, location="json") .add_argument("yaml_content", type=str, location="json")
parser.add_argument("name", type=str, location="json") .add_argument("yaml_url", type=str, location="json")
parser.add_argument("description", type=str, location="json") .add_argument("name", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json") .add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json") .add_argument("icon_type", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json") .add_argument("icon", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json") .add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Import app # Import app
account = cast(Account, current_user) account = current_user
result = import_service.import_rag_pipeline( result = import_service.import_rag_pipeline(
account=account, account=account,
import_mode=args["mode"], import_mode=args["mode"],
@ -74,15 +73,16 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self, import_id): def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first # Check user role first
if not current_user.is_editor: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
# Confirm import # Confirm import
account = cast(Account, current_user) account = current_user
result = import_service.confirm_import(import_id=import_id, account=account) result = import_service.confirm_import(import_id=import_id, account=account)
session.commit() session.commit()
@ -100,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields) @marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -117,12 +118,12 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
if not current_user.is_editor: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
parser.add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
edit_permission_required,
setup_required, setup_required,
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
) )
from libs import helper from libs import helper
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, current_user, login_required
from models.account import Account from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from models.model import EndUser from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get draft rag pipeline's workflow Get draft rag pipeline's workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model # fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Sync draft workflow Sync draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = reqparse.RequestParser() parser = (
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("hash", type=str, required=False, location="json") .add_argument("graph", type=dict, required=True, nullable=False, location="json")
parser.add_argument("environment_variables", type=list, required=False, location="json") .add_argument("hash", type=str, required=False, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") .add_argument("environment_variables", type=list, required=False, location="json")
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json") .add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -154,16 +154,15 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
Run draft workflow iteration node Run draft workflow iteration node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -194,11 +193,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -229,14 +228,17 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow Run draft workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -264,17 +266,20 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow Run published workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_info_list", type=list, required=True, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) .add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming") .add_argument("is_preview", type=bool, required=True, location="json", default=False)
parser.add_argument("original_document_id", type=str, required=False, location="json") .add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -303,15 +308,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -344,15 +350,16 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource # Run rag pipeline datasource
# """ # """
# # The role of the current user in the ta table must be admin, owner, or editor # # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor: # if not current_user.has_edit_permission:
# raise Forbidden() # raise Forbidden()
# #
# if not isinstance(current_user, Account): # if not isinstance(current_user, Account):
# raise Forbidden() # raise Forbidden()
# #
# parser = reqparse.RequestParser() # parser = (reqparse.RequestParser()
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json") # .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# parser.add_argument("datasource_type", type=str, required=True, location="json") # .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args() # args = parser.parse_args()
# #
# job_id = args.get("job_id") # job_id = args.get("job_id")
@ -385,13 +392,16 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -428,13 +438,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource Run rag pipeline datasource
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json") .add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -472,11 +485,13 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node Run draft workflow node
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") "inputs", type=dict, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") inputs = args.get("inputs")
@ -505,7 +520,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task Stop workflow task
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -525,7 +541,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline Get published pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
if not pipeline.is_published: if not pipeline.is_published:
return None return None
@ -545,7 +562,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow Publish workflow
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -580,7 +598,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
# Get default block configs # Get default block configs
@ -599,11 +618,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
parser.add_argument("q", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
q = args.get("q") q = args.get("q")
@ -631,14 +650,17 @@ class PublishedAllRagPipelineApi(Resource):
""" """
Get published workflows Get published workflows
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
parser.add_argument("user_id", type=str, required=False, location="args") .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") .add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = int(args.get("page", 1)) page = int(args.get("page", 1))
limit = int(args.get("limit", 10)) limit = int(args.get("limit", 10))
@ -681,12 +703,15 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes Update workflow attributes
""" """
# Check permission # Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("marked_name", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("marked_comment", type=str, required=False, location="json") .add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# Validate name and comment length # Validate name and comment length
@ -733,15 +758,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -759,15 +781,12 @@ class PublishedRagPipelineFirstStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -785,15 +804,12 @@ class DraftRagPipelineFirstStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get first step parameters of rag pipeline Get first step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -811,15 +827,12 @@ class DraftRagPipelineSecondStepApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get second step parameters of rag pipeline Get second step parameters of rag pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
node_id = args.get("node_id") node_id = args.get("node_id")
if not node_id: if not node_id:
@ -843,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
@ -880,7 +895,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields) @marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id): def get(self, pipeline: Pipeline, run_id: str):
""" """
Get workflow run node execution list Get workflow run node execution list
""" """
@ -903,14 +918,8 @@ class DatasourceListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, current_tenant_id = current_account_with_tenant()
if not isinstance(user, Account): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@ -940,9 +949,8 @@ class RagPipelineTransformApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id: str):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise Forbidden()
if not (current_user.has_edit_permission or current_user.is_dataset_operator): if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden() raise Forbidden()
@ -959,19 +967,20 @@ class RagPipelineDatasourceVariableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
Set datasource variables Set datasource variables
""" """
if not isinstance(current_user, Account) or not current_user.has_edit_permission: current_user, _ = current_account_with_tenant()
raise Forbidden() parser = (
reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json") .add_argument("datasource_info", type=dict, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json") .add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") .add_argument("start_node_title", type=str, required=True, location="json")
parser.add_argument("start_node_title", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()

View File

@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"provider", .add_argument(
type=str, "provider",
choices=["firecrawl", "watercrawl", "jinareader"], type=str,
required=True, choices=["firecrawl", "watercrawl", "jinareader"],
nullable=True, required=True,
location="json", nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
) )
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Create typed request and validate # Create typed request and validate
@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from models.dataset import Pipeline from models.dataset import Pipeline
@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"): if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters") raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user is not an account")
pipeline_id = kwargs.get("pipeline_id") pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id) pipeline_id = str(pipeline_id)
@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = ( pipeline = (
db.session.query(Pipeline) db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first() .first()
) )

View File

@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = reqparse.RequestParser() parser = (
parser.add_argument("message_id", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("voice", type=str, location="json") .add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("text", type=str, location="json") .add_argument("voice", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json") .add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get("message_id", None)

View File

@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, location="json", default="") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, location="json", default="")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, location="json") reqparse.RequestParser()
parser.add_argument("query", type=str, required=True, location="json") .add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json") .add_argument("query", type=str, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json") .add_argument("files", type=list, required=False, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") .add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args() args = parser.parse_args()
args["auto_generate_name"] = False args["auto_generate_name"] = False

View File

@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args() args = parser.parse_args()
pinned = None pinned = None
@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") .add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:

View File

@ -12,10 +12,9 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -29,9 +28,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
app_id = request.args.get("app_id", default=None, type=str) app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
if app_id: if app_id:
installed_apps = db.session.scalars( installed_apps = db.session.scalars(
@ -69,31 +66,26 @@ class InstalledAppsListApi(Resource):
# Pre-filter out apps without setting or with sso_verified # Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = [] filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list: for installed_app in installed_app_list:
app_id = installed_app["app"].id app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id) webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified": if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue continue
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app) filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check # Batch permission check
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id, user_id=user_id,
app_codes=app_codes, app_ids=app_ids,
) )
# Keep only allowed apps # Keep only allowed apps
res = [] res = []
for installed_app in filtered_installed_apps: for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id] if permissions.get(app_id):
if permissions.get(app_code):
res.append(installed_app) res.append(installed_app)
installed_app_list = res installed_app_list = res
@ -113,17 +105,15 @@ class InstalledAppsListApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None: if app is None:
@ -163,9 +153,8 @@ class InstalledAppApi(InstalledAppResource):
""" """
def delete(self, installed_app): def delete(self, installed_app):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_tenant_id:
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app) db.session.delete(installed_app)
@ -174,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
return {"result": "success", "message": "App uninstalled successfully"}, 204 return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app): def patch(self, installed_app):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
parser.add_argument("is_pinned", type=inputs.boolean)
args = parser.parse_args() args = parser.parse_args()
commit_args = False commit_args = False

View File

@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
@ -48,21 +47,22 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource): class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("first_id", type=uuid_value, location="args") .add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
) )
@ -78,18 +78,19 @@ class MessageListApi(InstalledAppResource):
) )
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") reqparse.RequestParser()
parser.add_argument("content", type=str, location="json") .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback( MessageService.create_feedback(
app_model=app_model, app_model=app_model,
message_id=message_id, message_id=message_id,
@ -109,14 +110,14 @@ class MessageFeedbackApi(InstalledAppResource):
) )
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
) )
args = parser.parse_args() args = parser.parse_args()
@ -124,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
@ -159,6 +158,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
) )
class MessageSuggestedQuestionApi(InstalledAppResource): class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -167,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
) )

View File

@ -42,8 +42,7 @@ class RecommendedAppListApi(Resource):
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
# language args # language args
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
parser.add_argument("language", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
language = args.get("language") language = args.get("language")

View File

@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from libs.login import current_user from libs.login import current_account_with_tenant
from models import Account
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
@ -35,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("last_id", type=uuid_value, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") .add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app): def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -72,6 +70,7 @@ class SavedMessageListApi(InstalledAppResource):
) )
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id): def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app app_model = installed_app.app
message_id = str(message_id) message_id = str(message_id)
@ -79,8 +78,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion": if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper from libs import helper
from libs.login import current_user from libs.login import current_user as current_user_
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -31,6 +31,8 @@ from .. import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run") @console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource): class InstalledAppWorkflowRunApi(InstalledAppResource):
@ -45,9 +47,11 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("files", type=list, required=False, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
assert current_user is not None assert current_user is not None
try: try:

View File

@ -8,10 +8,8 @@ from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models import InstalledApp from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -24,13 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first() .first()
) )
@ -56,14 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id), user_id=str(current_user.id),
app_code=app_code, app_id=app_id,
) )
if not res: if not res:
raise AppAccessDeniedError() raise AppAccessDeniedError()

View File

@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
@ -30,8 +29,7 @@ class CodeBasedExtensionAPI(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
parser.add_argument("module", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
@ -47,9 +45,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def get(self): def get(self):
assert isinstance(current_user, Account) _, tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension") @api.doc("create_api_based_extension")
@ -70,16 +66,17 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def post(self): def post(self):
assert isinstance(current_user, Account) parser = (
assert current_user.current_tenant_id is not None reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("name", type=str, required=True, location="json")
parser.add_argument("name", type=str, required=True, location="json") .add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json") .add_argument("api_key", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension( extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
name=args["name"], name=args["name"],
api_endpoint=args["api_endpoint"], api_endpoint=args["api_endpoint"],
api_key=args["api_key"], api_key=args["api_key"],
@ -99,10 +96,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def get(self, id): def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@ -125,17 +120,17 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_fields) @marshal_with(api_based_extension_fields)
def post(self, id): def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("api_endpoint", type=str, required=True, location="json") .add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json") .add_argument("api_endpoint", type=str, required=True, location="json")
.add_argument("api_key", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
extension_data_from_db.name = args["name"] extension_data_from_db.name = args["name"]
@ -154,12 +149,10 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, id): def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id) api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db) APIBasedExtensionService.delete(extension_data_from_db)

View File

@ -1,7 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import api, console_ns from . import api, console_ns
@ -23,9 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record @cloud_utm_record
def get(self): def get(self):
"""Get feature configuration for current tenant""" """Get feature configuration for current tenant"""
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump() return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features") @console_ns.route("/system-features")

View File

@ -1,7 +1,6 @@
from typing import Literal from typing import Literal
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -22,8 +21,7 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns from . import console_ns
@ -53,6 +51,7 @@ class FileApi(Resource):
@marshal_with(file_fields) @marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents") @cloud_edition_billing_resource_check("documents")
def post(self): def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source") source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -65,16 +64,12 @@ class FileApi(Resource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor: if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
if source not in ("datasets", None): if source not in ("datasets", None):
source = None source = None
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
try: try:
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file.filename, filename=file.filename,
@ -108,4 +103,4 @@ class FileSupportTypeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS} return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@ -57,8 +57,7 @@ class InitValidateAPI(Resource):
if tenant_count > 0: if tenant_count > 0:
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"] input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):

View File

@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account
from services.file_service import FileService from services.file_service import FileService
from . import console_ns from . import console_ns
@ -41,8 +40,7 @@ class RemoteFileInfoApi(Resource):
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args() args = parser.parse_args()
url = args["url"] url = args["url"]
@ -64,8 +62,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try: try:
assert isinstance(current_user, Account) user, _ = current_account_with_tenant()
user = current_user
upload_file = FileService(db.engine).upload_file( upload_file = FileService(db.engine).upload_file(
filename=file_info.filename, filename=file_info.filename,
content=content, content=content,

View File

@ -69,10 +69,12 @@ class SetupApi(Resource):
if not get_init_validate_status(): if not get_init_validate_status():
raise NotInitValidateError() raise NotInitValidateError()
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") .add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# setup # setup

View File

@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from models.model import Tag from models.model import Tag
from services.tag_service import TagService from services.tag_service import TagService
@ -24,11 +23,10 @@ class TagListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(dataset_tag_fields) @marshal_with(dataset_tag_fields)
def get(self): def get(self):
assert isinstance(current_user, Account) _, current_tenant_id = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_type = request.args.get("type", type=str, default="") tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str) keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
return tags, 200 return tags, 200
@ -36,18 +34,23 @@ class TagListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name .add_argument(
) "name",
parser.add_argument( nullable=False,
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
tag = TagService.save_tags(args) tag = TagService.save_tags(args)
@ -63,15 +66,13 @@ class TagUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, tag_id): def patch(self, tag_id):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
) )
args = parser.parse_args() args = parser.parse_args()
@ -87,8 +88,7 @@ class TagUpdateDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, tag_id): def delete(self, tag_id):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission: if not current_user.has_edit_permission:
@ -105,21 +105,22 @@ class TagBindingCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument( reqparse.RequestParser()
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." .add_argument(
) "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
parser.add_argument( )
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." .add_argument(
) "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
parser.add_argument( )
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." .add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
TagService.save_tag_binding(args) TagService.save_tag_binding(args)
@ -133,17 +134,18 @@ class TagBindingDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
assert isinstance(current_user, Account) current_user, _ = current_account_with_tenant()
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor): if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") reqparse.RequestParser()
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
parser.add_argument( .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." .add_argument(
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
)
) )
args = parser.parse_args() args = parser.parse_args()
TagService.delete_tag_binding(args) TagService.delete_tag_binding(args)

View File

@ -37,8 +37,7 @@ class VersionApi(Resource):
) )
def get(self): def get(self):
"""Check for application version updates""" """Check for application version updates"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
parser.add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args() args = parser.parse_args()
check_update_url = dify_config.CHECK_UPDATE_URL check_update_url = dify_config.CHECK_UPDATE_URL

View File

@ -2,11 +2,11 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P") P = ParamSpec("P")
@ -20,8 +20,9 @@ def plugin_permission_required(
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
with Session(db.engine) as session: with Session(db.engine) as session:
permission = ( permission = (

View File

@ -2,7 +2,6 @@ from datetime import datetime
import pytz import pytz
from flask import request from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -37,9 +36,8 @@ from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode from models import Account, AccountIntegrate, InvitationCode
from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -50,9 +48,7 @@ class AccountInitApi(Resource):
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
if account.status == "active": if account.status == "active":
raise AccountAlreadyInitedError() raise AccountAlreadyInitedError()
@ -61,9 +57,9 @@ class AccountInitApi(Resource):
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
parser.add_argument("invitation_code", type=str, location="json") parser.add_argument("invitation_code", type=str, location="json")
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
parser.add_argument("interface_language", type=supported_language, required=True, location="json") "timezone", type=timezone, required=True, location="json"
parser.add_argument("timezone", type=timezone, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields) @marshal_with(account_fields)
@enterprise_license_required @enterprise_license_required
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
return current_user return current_user
@ -118,10 +113,8 @@ class AccountNameApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate account name length # Validate account name length
@ -140,10 +133,8 @@ class AccountAvatarApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
@ -158,10 +149,10 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "interface_language", type=supported_language, required=True, location="json"
parser.add_argument("interface_language", type=supported_language, required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
@ -176,10 +167,10 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") )
args = parser.parse_args() args = parser.parse_args()
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
@ -194,10 +185,8 @@ class AccountTimezoneApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# Validate timezone string, e.g. America/New_York, Asia/Shanghai # Validate timezone string, e.g. America/New_York, Asia/Shanghai
@ -216,12 +205,13 @@ class AccountPasswordApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json") .add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json") .add_argument("new_password", type=str, required=True, location="json")
parser.add_argument("repeat_new_password", type=str, required=True, location="json") .add_argument("repeat_new_password", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["new_password"] != args["repeat_new_password"]: if args["new_password"] != args["repeat_new_password"]:
@ -253,9 +243,7 @@ class AccountIntegrateApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
account_integrates = db.session.scalars( account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
@ -298,9 +286,7 @@ class AccountDeleteVerifyApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
token, code = AccountService.generate_account_deletion_verification_code(account) token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code) AccountService.send_account_deletion_verification_email(account, code)
@ -314,13 +300,13 @@ class AccountDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if not AccountService.verify_account_deletion_code(args["token"], args["code"]): if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
@ -335,9 +321,11 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("feedback", type=str, required=True, location="json") .add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
@ -358,9 +346,7 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(verify_fields) @marshal_with(verify_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
@ -380,14 +366,14 @@ class EducationApi(Resource):
@only_edition_cloud @only_edition_cloud
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
def post(self): def post(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
parser = reqparse.RequestParser() parser = (
parser.add_argument("token", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("institution", type=str, required=True, location="json") .add_argument("token", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json") .add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"]) return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@ -399,9 +385,7 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(status_fields) @marshal_with(status_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): account, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
account = current_user
res = BillingService.EducationIdentity.status(account.id) res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat # convert expire_at to UTC timestamp from isoformat
@ -425,10 +409,12 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(data_fields) @marshal_with(data_fields)
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("keywords", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=0) .add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("limit", type=int, required=False, location="args", default=20) .add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
args = parser.parse_args() args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
@ -441,11 +427,14 @@ class ChangeEmailSendEmailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() current_user, _ = current_account_with_tenant()
parser.add_argument("email", type=email, required=True, location="json") parser = (
parser.add_argument("language", type=str, required=False, location="json") reqparse.RequestParser()
parser.add_argument("phase", type=str, required=False, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("token", type=str, required=False, location="json") .add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
@ -467,8 +456,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email: if user_email != current_user.email:
raise InvalidEmailError() raise InvalidEmailError()
else: else:
@ -490,10 +477,12 @@ class ChangeEmailCheckApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, location="json") .add_argument("email", type=email, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user_email = args["email"] user_email = args["email"]
@ -533,9 +522,11 @@ class ChangeEmailResetApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("new_email", type=email, required=True, location="json") reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]): if AccountService.is_account_in_freeze(args["new_email"]):
@ -551,8 +542,7 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()
@ -569,8 +559,7 @@ class ChangeEmailResetApi(Resource):
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
parser.add_argument("email", type=email, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["email"]): if AccountService.is_account_in_freeze(args["email"]):
raise AccountInFreezeError() raise AccountInFreezeError()

View File

@ -3,8 +3,7 @@ from flask_restx import Resource, fields
from controllers.console import api, console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.agent_service import AgentService from services.agent_service import AgentService
@ -21,12 +20,11 @@ class AgentProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
assert user.current_tenant_id is not None
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@ -45,9 +43,5 @@ class AgentProviderApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_name: str): def get(self, provider_name: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
user = current_user return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))

View File

@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.plugin.endpoint_service import EndpointService from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@api.doc("create_endpoint") @api.doc("create_endpoint")
@ -41,14 +33,16 @@ class EndpointCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("plugin_unique_identifier", type=str, required=True) reqparse.RequestParser()
parser.add_argument("settings", type=dict, required=True) .add_argument("plugin_unique_identifier", type=str, required=True)
parser.add_argument("name", type=str, required=True) .add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args() args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"] plugin_unique_identifier = args["plugin_unique_identifier"]
@ -87,11 +81,13 @@ class EndpointListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = args["page"] page = args["page"]
@ -130,12 +126,14 @@ class EndpointListForSinglePluginApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
parser.add_argument("plugin_id", type=str, required=True, location="args") .add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
page = args["page"] page = args["page"]
@ -172,10 +170,9 @@ class EndpointDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
@ -212,12 +209,14 @@ class EndpointUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("endpoint_id", type=str, required=True) reqparse.RequestParser()
parser.add_argument("settings", type=dict, required=True) .add_argument("endpoint_id", type=str, required=True)
parser.add_argument("name", type=str, required=True) .add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
@ -255,10 +254,9 @@ class EndpointEnableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]
@ -288,10 +286,9 @@ class EndpointDisableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = _current_account_with_tenant() user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
parser.add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
endpoint_id = args["endpoint_id"] endpoint_id = args["endpoint_id"]

View File

@ -5,8 +5,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
@ -18,24 +18,25 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# validate model load balancing credentials # validate model load balancing credentials
@ -72,24 +73,25 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str, config_id: str): def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account) current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
# validate model load balancing config credentials # validate model load balancing config credentials

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError from services.errors.account import AccountAlreadyInTenantError
@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant) members = TenantService.get_tenant_members(current_user.current_tenant)
@ -58,10 +57,12 @@ class MemberInviteEmailApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("members") @cloud_edition_billing_resource_check("members")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("emails", type=list, required=True, location="json") reqparse.RequestParser()
parser.add_argument("role", type=str, required=True, default="admin", location="json") .add_argument("emails", type=list, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json") .add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
invitee_emails = args["emails"] invitee_emails = args["emails"]
@ -69,9 +70,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"] interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role): if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
inviter = current_user inviter = current_user
if not inviter.current_tenant: if not inviter.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
@ -120,8 +119,7 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first() member = db.session.query(Account).where(Account.id == str(member_id)).first()
@ -153,16 +151,13 @@ class MemberUpdateRoleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, member_id): def put(self, member_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
new_role = args["role"] new_role = args["role"]
if not TenantAccountRole.is_valid_role(new_role): if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
@ -189,8 +184,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(account_with_role_list_fields) @marshal_with(account_with_role_list_fields)
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant) members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -206,16 +200,13 @@ class SendOwnerTransferEmailApi(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -245,13 +236,14 @@ class OwnerTransferCheckApi(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("code", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") .add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -291,13 +283,13 @@ class OwnerTransfer(Resource):
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self, member_id): def post(self, member_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("token", type=str, required=True, nullable=False, location="json") "token", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.current_tenant: if not current_user.current_tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -1,7 +1,6 @@
import io import io
from flask import send_file from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -23,14 +21,10 @@ class ModelProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") tenant_id = current_tenant_id
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"model_type", "model_type",
type=str, type=str,
required=False, required=False,
@ -52,14 +46,12 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") tenant_id = current_tenant_id
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") "credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -73,23 +65,22 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.create_provider_credential( model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_name=args["name"], credential_name=args["name"],
@ -103,24 +94,23 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try: try:
model_provider_service.update_provider_credential( model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_id=args["credential_id"], credential_id=args["credential_id"],
@ -135,19 +125,17 @@ class ModelProviderCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") "credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential( model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
) )
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -159,19 +147,17 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") "credential_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_provider_credential( service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credential_id=args["credential_id"], credential_id=args["credential_id"],
) )
@ -184,15 +170,13 @@ class ModelProviderValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument(
parser = reqparse.RequestParser() "credentials", type=dict, required=True, nullable=False, location="json"
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id: tenant_id = current_tenant_id
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -240,17 +224,13 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
if not current_user.current_tenant_id: tenant_id = current_tenant_id
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument(
"preferred_provider_type", "preferred_provider_type",
type=str, type=str,
required=True, required=True,
@ -276,14 +256,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str): def get(self, provider: str):
if provider != "anthropic": if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid") raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link( data = BillingService.get_model_provider_payment_link(
provider_name=provider, provider_name=provider,
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
account_id=current_user.id, account_id=current_user.id,
prefilled_email=current_user.email, prefilled_email=current_user.email,
) )

View File

@ -1,6 +1,5 @@
import logging import logging
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value from libs.helper import StrLen, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
@ -23,8 +22,9 @@ class DefaultModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() _, tenant_id = current_account_with_tenant()
parser.add_argument(
parser = reqparse.RequestParser().add_argument(
"model_type", "model_type",
type=str, type=str,
required=True, required=True,
@ -34,8 +34,6 @@ class DefaultModelApi(Resource):
) )
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type( default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"] tenant_id=tenant_id, model_type=args["model_type"]
@ -47,15 +45,15 @@ class DefaultModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") "model_settings", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_settings = args["model_settings"] model_settings = args["model_settings"]
for model_setting in model_settings: for model_setting in model_settings:
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -104,24 +102,26 @@ class ModelProviderModelApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
# To save the model's load balance configs # To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id parser = (
reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument("model", type=str, required=True, nullable=False, location="json") .add_argument(
parser.add_argument( "model_type",
"model_type", type=str,
type=str, required=True,
required=True, nullable=False,
nullable=False, choices=[mt.value for mt in ModelType],
choices=[mt.value for mt in ModelType], location="json",
location="json", )
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
) )
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if args.get("config_from", "") == "custom-model": if args.get("config_from", "") == "custom-model":
@ -129,7 +129,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model") raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_custom_model_credential( service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -164,20 +164,22 @@ class ModelProviderModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id parser = (
reqparse.RequestParser()
parser = reqparse.RequestParser() .add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument("model", type=str, required=True, nullable=False, location="json") .add_argument(
parser.add_argument( "model_type",
"model_type", type=str,
type=str, required=True,
required=True, nullable=False,
nullable=False, choices=[mt.value for mt in ModelType],
choices=[mt.value for mt in ModelType], location="json",
location="json", )
) )
args = parser.parse_args() args = parser.parse_args()
@ -195,20 +197,22 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="args") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="args")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="args", choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
) )
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -257,24 +261,27 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
try: try:
@ -301,29 +308,33 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
try: try:
model_provider_service.update_model_credential( model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -340,24 +351,28 @@ class ModelProviderModelCredentialApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
) )
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential( model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -373,24 +388,28 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
) )
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
service = ModelProviderService() service = ModelProviderService()
service.add_model_credential_to_model_list( service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args["model_type"], model_type=args["model_type"],
model=args["model"], model=args["model"],
@ -407,17 +426,19 @@ class ModelProviderModelEnableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -437,17 +458,19 @@ class ModelProviderModelDisableApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
) )
args = parser.parse_args() args = parser.parse_args()
@ -465,19 +488,21 @@ class ModelProviderModelValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("model", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument( .add_argument("model", type=str, required=True, nullable=False, location="json")
"model_type", .add_argument(
type=str, "model_type",
required=True, type=str,
nullable=False, required=True,
choices=[mt.value for mt in ModelType], nullable=False,
location="json", choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -511,11 +536,11 @@ class ModelProviderModelParameterRuleApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("model", type=str, required=True, nullable=False, location="args") "model", type=str, required=True, nullable=False, location="args"
)
args = parser.parse_args() args = parser.parse_args()
_, tenant_id = current_account_with_tenant()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules( parameter_rules = model_provider_service.get_model_parameter_rules(
@ -531,8 +556,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, model_type): def get(self, model_type):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -1,7 +1,6 @@
import io import io
from flask import request, send_file from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_parameter_service import PluginParameterService
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(debug_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return { return {
@ -44,10 +43,12 @@ class PluginListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=False, location="args", default=1) reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=False, location="args", default=256) .add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
args = parser.parse_args() args = parser.parse_args()
try: try:
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"]) plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
@ -63,8 +64,7 @@ class PluginListLatestVersionsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
req.add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args() args = req.parse_args()
try: try:
@ -81,10 +81,9 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
parser.add_argument("plugin_ids", type=list, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -99,9 +98,11 @@ class PluginListInstallationsFromIdsApi(Resource):
class PluginIconApi(Resource): class PluginIconApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
req = reqparse.RequestParser() req = (
req.add_argument("tenant_id", type=str, required=True, location="args") reqparse.RequestParser()
req.add_argument("filename", type=str, required=True, location="args") .add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
args = req.parse_args() args = req.parse_args()
try: try:
@ -120,7 +121,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
file = request.files["pkg"] file = request.files["pkg"]
@ -144,12 +145,14 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("repo", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -167,7 +170,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
file = request.files["bundle"] file = request.files["bundle"]
@ -191,10 +194,11 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") "plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string # check if all plugin_unique_identifiers are valid string
@ -217,13 +221,15 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("repo", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json") .add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -247,10 +253,11 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") "plugin_unique_identifiers", type=list, required=True, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# check if all plugin_unique_identifiers are valid string # check if all plugin_unique_identifiers are valid string
@ -273,10 +280,11 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") "plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -299,10 +307,11 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") "plugin_unique_identifier", type=str, required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -324,11 +333,13 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=int, required=True, location="args") reqparse.RequestParser()
parser.add_argument("page_size", type=int, required=True, location="args") .add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -346,7 +357,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self, task_id: str): def get(self, task_id: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@ -361,7 +372,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str): def post(self, task_id: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)} return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@ -376,7 +387,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)} return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@ -391,7 +402,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str): def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@ -406,11 +417,13 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -430,14 +443,16 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser() parser = (
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("repo", type=str, required=True, location="json") .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
parser.add_argument("version", type=str, required=True, location="json") .add_argument("repo", type=str, required=True, location="json")
parser.add_argument("package", type=str, required=True, location="json") .add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -462,11 +477,10 @@ class PluginUninstallApi(Resource):
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args() args = req.parse_args()
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@ -480,19 +494,22 @@ class PluginChangePermissionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user user = current_user
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
req = reqparse.RequestParser() req = (
req.add_argument("install_permission", type=str, required=True, location="json") reqparse.RequestParser()
req.add_argument("debug_permission", type=str, required=True, location="json") .add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
args = req.parse_args() args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@ -503,7 +520,7 @@ class PluginFetchPermissionApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id) permission = PluginPermissionService.get_permission(tenant_id)
if not permission: if not permission:
@ -529,18 +546,20 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
# check if the user is admin or owner # check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id user_id = current_user.id
parser = reqparse.RequestParser() parser = (
parser.add_argument("plugin_id", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, location="args") .add_argument("plugin_id", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args") .add_argument("provider", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args") .add_argument("action", type=str, required=True, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args") .add_argument("parameter", type=str, required=True, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -565,17 +584,17 @@ class PluginChangePreferencesApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
req = reqparse.RequestParser() req = (
req.add_argument("permission", type=dict, required=True, location="json") reqparse.RequestParser()
req.add_argument("auto_upgrade", type=dict, required=True, location="json") .add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
args = req.parse_args() args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"] permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@ -621,7 +640,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id) permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = { permission_dict = {
@ -661,10 +680,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
# exclude one single plugin # exclude one single plugin
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
req.add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args() args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})

View File

@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import ( from flask_restx import (
Resource, Resource,
reqparse, reqparse,
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.api_tools_manage_service import ApiToolManageService
@ -53,13 +52,11 @@ class ToolProviderListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser() req = reqparse.RequestParser().add_argument(
req.add_argument(
"type", "type",
type=str, type=str,
choices=["builtin", "model", "api", "workflow", "mcp"], choices=["builtin", "model", "api", "workflow", "mcp"],
@ -78,9 +75,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools( BuiltinToolManageService.list_builtin_tool_provider_tools(
@ -96,9 +91,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -109,13 +102,13 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id req = reqparse.RequestParser().add_argument(
req = reqparse.RequestParser() "credential_id", type=str, required=True, nullable=False, location="json"
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") )
args = req.parse_args() args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider( return BuiltinToolManageService.delete_builtin_tool_provider(
@ -131,15 +124,16 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json") .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
.add_argument("type", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if args["type"] not in CredentialType.values(): if args["type"] not in CredentialType.values():
@ -161,18 +155,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") .add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") .add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -193,7 +188,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials( BuiltinToolManageService.get_builtin_tool_provider_credentials(
@ -218,23 +213,24 @@ class ToolApiProviderAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -258,14 +254,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -282,14 +275,13 @@ class ToolApiProviderListToolsApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") )
args = parser.parse_args() args = parser.parse_args()
@ -308,24 +300,25 @@ class ToolApiProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = (
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") .add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("original_provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") .add_argument("labels", type=list[str], required=False, nullable=True, location="json")
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
@ -350,17 +343,16 @@ class ToolApiProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="json"
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
@ -377,14 +369,13 @@ class ToolApiProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"provider", type=str, required=True, nullable=False, location="args"
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") )
args = parser.parse_args() args = parser.parse_args()
@ -401,8 +392,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider, credential_type): def get(self, provider, credential_type):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema( BuiltinToolManageService.list_builtin_provider_credentials_schema(
@ -417,9 +407,9 @@ class ToolApiProviderSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
"schema", type=str, required=True, nullable=False, location="json"
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") )
args = parser.parse_args() args = parser.parse_args()
@ -434,19 +424,20 @@ class ToolApiProviderPreviousTestApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") .add_argument("tool_name", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") .add_argument("provider_name", type=str, required=False, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") .add_argument("parameters", type=dict, required=True, nullable=False, location="json")
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") .add_argument("schema_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("schema", type=str, required=True, nullable=False, location="json") .add_argument("schema", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args() args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview( return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id, current_tenant_id,
args["provider_name"] or "", args["provider_name"] or "",
args["tool_name"], args["tool_name"],
args["credentials"], args["credentials"],
@ -462,23 +453,24 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = (
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") .add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -502,23 +494,24 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = (
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") reqparse.RequestParser()
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") .add_argument("label", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") .add_argument("description", type=str, required=True, nullable=False, location="json")
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") .add_argument("icon", type=dict, required=True, nullable=False, location="json")
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -545,16 +538,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser() reqparser = reqparse.RequestParser().add_argument(
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
)
args = reqparser.parse_args() args = reqparser.parse_args()
@ -571,14 +564,15 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = (
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") reqparse.RequestParser()
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
)
args = parser.parse_args() args = parser.parse_args()
@ -606,13 +600,13 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
)
args = parser.parse_args() args = parser.parse_args()
@ -631,10 +625,9 @@ class ToolBuiltinListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -653,8 +646,7 @@ class ToolApiListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -672,10 +664,9 @@ class ToolWorkflowListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user, tenant_id = current_account_with_tenant()
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
@ -709,19 +700,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name provider_name = tool_provider.provider_name
# todo check permission # todo check permission
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None: if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context( context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
) )
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
@ -800,11 +790,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
parser = reqparse.RequestParser() current_user, current_tenant_id = current_account_with_tenant()
parser.add_argument("id", type=str, required=True, nullable=False, location="json") parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
return BuiltinToolManageService.set_default_provider( return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
) )
@ -814,18 +804,20 @@ class ToolOAuthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
parser = reqparse.RequestParser() parser = (
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") reqparse.RequestParser()
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") .add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
user = current_user user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params( return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
client_params=args.get("client_params", {}), client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@ -835,20 +827,18 @@ class ToolOAuthCustomClient(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params( BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
tenant_id=current_user.current_tenant_id, provider=provider
)
) )
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider): def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params( BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
tenant_id=current_user.current_tenant_id, provider=provider
)
) )
@ -858,9 +848,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider tenant_id=current_tenant_id, provider_name=provider
) )
) )
@ -871,7 +862,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info( BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@ -887,25 +878,25 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") .add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument( .add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 .add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
) )
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args() args = parser.parse_args()
user = current_user user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.") raise ValueError("Server URL is not valid.")
return jsonable_encoder( return jsonable_encoder(
MCPToolManageService.create_mcp_provider( MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id, tenant_id=tenant_id,
server_url=args["server_url"], server_url=args["server_url"],
name=args["name"], name=args["name"],
icon=args["icon"], icon=args["icon"],
@ -923,25 +914,28 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self): def put(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") .add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") .add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") .add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") .add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json") .add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]: if "[__HIDDEN__]" in args["server_url"]:
pass pass
else: else:
raise ValueError("Server URL is not valid.") raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider( MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_tenant_id,
provider_id=args["provider_id"], provider_id=args["provider_id"],
server_url=args["server_url"], server_url=args["server_url"],
name=args["name"], name=args["name"],
@ -959,10 +953,12 @@ class ToolProviderMCPApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self): def delete(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument(
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") "provider_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) _, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"} return {"result": "success"}
@ -972,12 +968,14 @@ class ToolMCPAuthApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") reqparse.RequestParser()
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json") .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
provider_id = args["provider_id"] provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider: if not provider:
raise ValueError("provider not found") raise ValueError("provider not found")
@ -1018,8 +1016,8 @@ class ToolMCPDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
user = current_user _, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id) provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1029,8 +1027,7 @@ class ToolMCPListAllApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user _, tenant_id = current_account_with_tenant()
tenant_id = user.current_tenant_id
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@ -1043,7 +1040,7 @@ class ToolMCPUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider_id): def get(self, provider_id):
tenant_id = current_user.current_tenant_id _, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server( tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=provider_id,
@ -1054,9 +1051,11 @@ class ToolMCPUpdateApi(Resource):
@console_ns.route("/mcp/oauth/callback") @console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("code", type=str, required=True, nullable=False, location="args") reqparse.RequestParser()
parser.add_argument("state", type=str, required=True, nullable=False, location="args") .add_argument("code", type=str, required=True, nullable=False, location="args")
.add_argument("state", type=str, required=True, nullable=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
state_key = args["state"] state_key = args["state"]
authorization_code = args["code"] authorization_code = args["code"]

View File

@ -23,8 +23,8 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import current_user, login_required from libs.login import current_account_with_tenant, login_required
from models.account import Account, Tenant, TenantStatus from models.account import Tenant, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.file_service import FileService from services.file_service import FileService
@ -70,8 +70,7 @@ class TenantListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
if not isinstance(current_user, Account): current_user, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account")
tenants = TenantService.get_join_tenants(current_user) tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = [] tenant_dicts = []
@ -85,7 +84,7 @@ class TenantListApi(Resource):
"status": tenant.status, "status": tenant.status,
"created_at": tenant.created_at, "created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, "current": tenant.id == current_tenant_id if current_tenant_id else False,
} }
tenant_dicts.append(tenant_dict) tenant_dicts.append(tenant_dict)
@ -98,9 +97,11 @@ class WorkspaceListApi(Resource):
@setup_required @setup_required
@admin_required @admin_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") reqparse.RequestParser()
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args() args = parser.parse_args()
stmt = select(Tenant).order_by(Tenant.created_at.desc()) stmt = select(Tenant).order_by(Tenant.created_at.desc())
@ -130,8 +131,7 @@ class TenantApi(Resource):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
tenant = current_user.current_tenant tenant = current_user.current_tenant
if not tenant: if not tenant:
raise ValueError("No current tenant") raise ValueError("No current tenant")
@ -155,10 +155,8 @@ class SwitchWorkspaceApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
# check if tenant_id is valid, 403 if not # check if tenant_id is valid, 403 if not
@ -181,16 +179,14 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") parser = (
parser = reqparse.RequestParser() reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json") .add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json") .add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args() args = parser.parse_args()
tenant = db.get_or_404(Tenant, current_tenant_id)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"], "remove_webapp_brand": args["remove_webapp_brand"],
@ -212,8 +208,7 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
if not isinstance(current_user, Account): current_user, _ = current_account_with_tenant()
raise ValueError("Invalid user account")
# check file # check file
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
@ -253,15 +248,13 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required @account_initialization_required
# Change workspace name # Change workspace name
def post(self): def post(self):
if not isinstance(current_user, Account): _, current_tenant_id = current_account_with_tenant()
raise ValueError("Invalid user account") parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.current_tenant_id: if not current_tenant_id:
raise ValueError("No current tenant") raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

View File

@ -12,8 +12,8 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.login import current_user from libs.login import current_account_with_tenant
from models.account import Account, AccountStatus from models.account import AccountStatus
from models.dataset import RateLimitLog from models.dataset import RateLimitLog
from models.model import DifySetup from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus from services.feature_service import FeatureService, LicenseStatus
@ -25,18 +25,12 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]): def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
account = _current_account() current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError() raise AccountNotInitializedError()
return view(*args, **kwargs) return view(*args, **kwargs)
@ -80,9 +74,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]): def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs) return view(*args, **kwargs)
@ -94,10 +87,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
apps = features.apps apps = features.apps
@ -138,9 +129,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
if features.billing.subscription.plan == "sandbox": if features.billing.subscription.plan == "sandbox":
@ -163,13 +153,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge": if resource == "knowledge":
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000) current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}" key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time}) redis_client.zadd(key, {current_time: current_time})
@ -180,7 +168,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit: if request_count > knowledge_rate_limit.limit:
# add ratelimit record # add ratelimit record
rate_limit_log = RateLimitLog( rate_limit_log = RateLimitLog(
tenant_id=tenant_id, tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan, subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge", operation="knowledge",
) )
@ -200,17 +188,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled: if features.billing.enabled:
utm_info = request.cookies.get("utm_info") utm_info = request.cookies.get("utm_info")
if utm_info: if utm_info:
utm_info_dict: dict = json.loads(utm_info) utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(tenant_id, utm_info_dict) OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs) return view(*args, **kwargs)
@ -260,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated return decorated
def email_register_enabled(view): def email_register_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.is_allow_register: if features.is_allow_register:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -289,9 +275,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -301,14 +286,31 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated return decorated
def knowledge_pipeline_publish_enabled(view): def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account() _, current_tenant_id = current_account_with_tenant()
assert account.current_tenant_id is not None features = FeatureService.get_features(current_tenant_id)
features = FeatureService.get_features(account.current_tenant_id)
if features.knowledge_pipeline.publish_enabled: if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs) return view(*args, **kwargs)
abort(403) abort(403)
return decorated return decorated
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -46,11 +46,13 @@ class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
parser = reqparse.RequestParser() parser = (
parser.add_argument("timestamp", type=str, required=True, location="args") reqparse.RequestParser()
parser.add_argument("nonce", type=str, required=True, location="args") .add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args") .add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") .add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()

View File

@ -16,12 +16,13 @@ class ToolFileApi(Resource):
def get(self, file_id, extension): def get(self, file_id, extension):
file_id = str(file_id) file_id = str(file_id)
parser = reqparse.RequestParser() parser = (
reqparse.RequestParser()
parser.add_argument("timestamp", type=str, required=True, location="args") .add_argument("timestamp", type=str, required=True, location="args")
parser.add_argument("nonce", type=str, required=True, location="args") .add_argument("nonce", type=str, required=True, location="args")
parser.add_argument("sign", type=str, required=True, location="args") .add_argument("sign", type=str, required=True, location="args")
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") .add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args() args = parser.parse_args()
if not verify_tool_file_signature( if not verify_tool_file_signature(

View File

@ -18,19 +18,17 @@ from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model from fields.file_fields import build_file_model
# Define parser for both documentation and validation # Define parser for both documentation and validation
upload_parser = reqparse.RequestParser() upload_parser = (
upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") reqparse.RequestParser()
upload_parser.add_argument( .add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" .add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
)
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
) )
upload_parser.add_argument(
"nonce", type=str, required=True, location="args", help="Random string for signature verification"
)
upload_parser.add_argument(
"sign", type=str, required=True, location="args", help="HMAC signature for request validation"
)
upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
@files_ns.route("/upload/for-plugin") @files_ns.route("/upload/for-plugin")

View File

@ -5,11 +5,13 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task from tasks.mail_inner_task import send_inner_email_task
_mail_parser = reqparse.RequestParser() _mail_parser = (
_mail_parser.add_argument("to", type=str, action="append", required=True) reqparse.RequestParser()
_mail_parser.add_argument("subject", type=str, required=True) .add_argument("to", type=str, action="append", required=True)
_mail_parser.add_argument("body", type=str, required=True) .add_argument("subject", type=str, required=True)
_mail_parser.add_argument("substitutions", type=dict, required=False) .add_argument("body", type=str, required=True)
.add_argument("substitutions", type=dict, required=False)
)
class BaseMail(Resource): class BaseMail(Resource):
@ -17,7 +19,7 @@ class BaseMail(Resource):
def post(self): def post(self):
args = _mail_parser.parse_args() args = _mail_parser.parse_args()
send_inner_email_task.delay( send_inner_email_task.delay( # type: ignore
to=args["to"], to=args["to"],
subject=args["subject"], subject=args["subject"],
body=args["body"], body=args["body"],

View File

@ -31,7 +31,7 @@ from core.plugin.entities.request import (
) )
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import length_prefixed_response from libs.helper import length_prefixed_response
from models.account import Account, Tenant from models import Account, Tenant
from models.model import EndUser from models.model import EndUser

View File

@ -72,9 +72,11 @@ def get_user_tenant(view: Callable[P, R] | None = None):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body # fetch json body
parser = reqparse.RequestParser() parser = (
parser.add_argument("tenant_id", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("user_id", type=str, required=True, location="json") .add_argument("tenant_id", type=str, required=True, location="json")
.add_argument("user_id", type=str, required=True, location="json")
)
p = parser.parse_args() p = parser.parse_args()

View File

@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models import Account
from services.account_service import TenantService from services.account_service import TenantService
@ -25,9 +25,11 @@ class EnterpriseWorkspace(Resource):
} }
) )
def post(self): def post(self):
parser = reqparse.RequestParser() parser = (
parser.add_argument("name", type=str, required=True, location="json") reqparse.RequestParser()
parser.add_argument("owner_email", type=str, required=True, location="json") .add_argument("name", type=str, required=True, location="json")
.add_argument("owner_email", type=str, required=True, location="json")
)
args = parser.parse_args() args = parser.parse_args()
account = db.session.query(Account).filter_by(email=args["owner_email"]).first() account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
@ -68,8 +70,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
} }
) )
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)

View File

@ -33,14 +33,12 @@ def int_or_str(value):
# Define parser for both documentation and validation # Define parser for both documentation and validation
mcp_request_parser = reqparse.RequestParser() mcp_request_parser = (
mcp_request_parser.add_argument( reqparse.RequestParser()
"jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')" .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
) .add_argument("method", type=str, required=True, location="json", help="The method to invoke")
mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke") .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
mcp_request_parser.add_argument(
"id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses"
) )

View File

@ -10,24 +10,24 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user from libs.login import current_user
from models.account import Account from models import Account
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
# Define parsers for annotation API # Define parsers for annotation API
annotation_create_parser = reqparse.RequestParser() annotation_create_parser = (
annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") reqparse.RequestParser()
annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") .add_argument("question", required=True, type=str, location="json", help="Annotation question")
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
)
annotation_reply_action_parser = reqparse.RequestParser() annotation_reply_action_parser = (
annotation_reply_action_parser.add_argument( reqparse.RequestParser()
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" .add_argument(
) "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
annotation_reply_action_parser.add_argument( )
"embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
) .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
annotation_reply_action_parser.add_argument(
"embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
) )

View File

@ -85,11 +85,13 @@ class AudioApi(Resource):
# Define parser for text-to-audio API # Define parser for text-to-audio API
text_to_audio_parser = reqparse.RequestParser() text_to_audio_parser = (
text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") reqparse.RequestParser()
text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") .add_argument("message_id", type=str, required=False, location="json", help="Message ID")
text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") .add_argument("voice", type=str, location="json", help="Voice to use for TTS")
text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") .add_argument("text", type=str, location="json", help="Text to convert to audio")
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
)
@service_api_ns.route("/text-to-audio") @service_api_ns.route("/text-to-audio")

View File

@ -37,40 +37,34 @@ logger = logging.getLogger(__name__)
# Define parser for completion API # Define parser for completion API
completion_parser = reqparse.RequestParser() completion_parser = (
completion_parser.add_argument( reqparse.RequestParser()
"inputs", type=dict, required=True, location="json", help="Input parameters for completion" .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
) .add_argument("query", type=str, location="json", default="", help="The query string")
completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
completion_parser.add_argument( .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
)
completion_parser.add_argument(
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
) )
# Define parser for chat API # Define parser for chat API
chat_parser = reqparse.RequestParser() chat_parser = (
chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") reqparse.RequestParser()
chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") .add_argument("query", type=str, required=True, location="json", help="The chat query")
chat_parser.add_argument( .add_argument("files", type=list, required=False, location="json", help="List of file attachments")
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
) )
chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
chat_parser.add_argument(
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
)
chat_parser.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
@service_api_ns.route("/completion-messages") @service_api_ns.route("/completion-messages")

View File

@ -24,48 +24,63 @@ from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
# Define parsers for conversation APIs # Define parsers for conversation APIs
conversation_list_parser = reqparse.RequestParser() conversation_list_parser = (
conversation_list_parser.add_argument( reqparse.RequestParser()
"last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
) .add_argument(
conversation_list_parser.add_argument( "limit",
"limit", type=int_range(1, 100),
type=int_range(1, 100), required=False,
required=False, default=20,
default=20, location="args",
location="args", help="Number of conversations to return",
help="Number of conversations to return", )
) .add_argument(
conversation_list_parser.add_argument( "sort_by",
"sort_by", type=str,
type=str, choices=["created_at", "-created_at", "updated_at", "-updated_at"],
choices=["created_at", "-created_at", "updated_at", "-updated_at"], required=False,
required=False, default="-updated_at",
default="-updated_at", location="args",
location="args", help="Sort order for conversations",
help="Sort order for conversations", )
) )
conversation_rename_parser = reqparse.RequestParser() conversation_rename_parser = (
conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") reqparse.RequestParser()
conversation_rename_parser.add_argument( .add_argument("name", type=str, required=False, location="json", help="New conversation name")
"auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" .add_argument(
"auto_generate",
type=bool,
required=False,
default=False,
location="json",
help="Auto-generate conversation name",
)
) )
conversation_variables_parser = reqparse.RequestParser() conversation_variables_parser = (
conversation_variables_parser.add_argument( reqparse.RequestParser()
"last_id", type=uuid_value, location="args", help="Last variable ID for pagination" .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
) .add_argument(
conversation_variables_parser.add_argument( "limit",
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of variables to return",
)
) )
conversation_variable_update_parser = reqparse.RequestParser() conversation_variable_update_parser = reqparse.RequestParser().add_argument(
# using lambda is for passing the already-typed value without modification # using lambda is for passing the already-typed value without modification
# if no lambda, it will be converted to string # if no lambda, it will be converted to string
# the string cannot be converted using json.loads # the string cannot be converted using json.loads
conversation_variable_update_parser.add_argument( "value",
"value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" required=True,
location="json",
type=lambda x: x,
help="New value for the conversation variable",
) )

View File

@ -18,8 +18,7 @@ logger = logging.getLogger(__name__)
# Define parser for file preview API # Define parser for file preview API
file_preview_parser = reqparse.RequestParser() file_preview_parser = reqparse.RequestParser().add_argument(
file_preview_parser.add_argument(
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
) )

View File

@ -26,25 +26,37 @@ logger = logging.getLogger(__name__)
# Define parsers for message APIs # Define parsers for message APIs
message_list_parser = reqparse.RequestParser() message_list_parser = (
message_list_parser.add_argument( reqparse.RequestParser()
"conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
) .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") .add_argument(
message_list_parser.add_argument( "limit",
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of messages to return",
)
) )
message_feedback_parser = reqparse.RequestParser() message_feedback_parser = (
message_feedback_parser.add_argument( reqparse.RequestParser()
"rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
.add_argument("content", type=str, location="json", help="Feedback content")
) )
message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content")
feedback_list_parser = reqparse.RequestParser() feedback_list_parser = (
feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") reqparse.RequestParser()
feedback_list_parser.add_argument( .add_argument("page", type=int, default=1, location="args", help="Page number")
"limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" .add_argument(
"limit",
type=int_range(1, 101),
required=False,
default=20,
location="args",
help="Number of feedbacks per page",
)
) )

View File

@ -42,32 +42,36 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define parsers for workflow APIs # Define parsers for workflow APIs
workflow_run_parser = reqparse.RequestParser() workflow_run_parser = (
workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") reqparse.RequestParser()
workflow_run_parser.add_argument("files", type=list, required=False, location="json") .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") .add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
)
workflow_log_parser = reqparse.RequestParser() workflow_log_parser = (
workflow_log_parser.add_argument("keyword", type=str, location="args") reqparse.RequestParser()
workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") .add_argument("keyword", type=str, location="args")
workflow_log_parser.add_argument("created_at__before", type=str, location="args") .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
workflow_log_parser.add_argument("created_at__after", type=str, location="args") .add_argument("created_at__before", type=str, location="args")
workflow_log_parser.add_argument( .add_argument("created_at__after", type=str, location="args")
"created_by_end_user_session_id", .add_argument(
type=str, "created_by_end_user_session_id",
location="args", type=str,
required=False, location="args",
default=None, required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
) )
workflow_log_parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
workflow_run_fields = { workflow_run_fields = {
"id": fields.String, "id": fields.String,

Some files were not shown because too many files have changed in this diff Show More