mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
74fc026fe0
|
|
@ -6,11 +6,10 @@ cd web && pnpm install
|
|||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ trim_trailing_whitespace = false
|
|||
|
||||
# Matches multiple files with brace expansion notation
|
||||
# Set default charset
|
||||
[*.{js,tsx}]
|
||||
[*.{js,jsx,ts,tsx,mjs}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ jobs:
|
|||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ jobs:
|
|||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- name: count migration progress
|
||||
run: |
|
||||
cd api
|
||||
./cnt_base.sh
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
db-migration-test:
|
||||
db-migration-test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
|
|
@ -45,7 +45,7 @@ jobs:
|
|||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
|
||||
- name: Prepare configs
|
||||
|
|
@ -57,3 +57,60 @@ jobs:
|
|||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_mysql
|
||||
redis
|
||||
|
||||
- name: Prepare configs for MySQL
|
||||
run: |
|
||||
cd api
|
||||
cp .env.example .env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
|
|
|||
|
|
@ -51,13 +51,13 @@ jobs:
|
|||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Vector Store (TiDB)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: docker/tidb/docker-compose.yaml
|
||||
services: |
|
||||
tidb
|
||||
tiflash
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
|
|
@ -83,8 +83,8 @@ jobs:
|
|||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Check VDB Ready (TiDB)
|
||||
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
|
|
|||
|
|
@ -186,6 +186,8 @@ docker/volumes/couchbase/*
|
|||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
docker/volumes/matrixone/*
|
||||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@
|
|||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
|
|
|||
8
Makefile
8
Makefile
|
|
@ -70,6 +70,11 @@ type-check:
|
|||
@uv run --directory api --dev basedpyright
|
||||
@echo "✅ Type check complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
|
|
@ -119,6 +124,7 @@ help:
|
|||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format and fix code with ruff"
|
||||
@echo " make type-check - Run type checking with basedpyright"
|
||||
@echo " make test - Run backend unit tests"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
|
|
@ -128,4 +134,4 @@ help:
|
|||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test
|
||||
|
|
|
|||
|
|
@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
|
|||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
# PostgreSQL database configuration
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
|
|
@ -159,12 +162,11 @@ SUPABASE_URL=your-server-url
|
|||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
|
||||
# Provide the registrable domain (e.g. example.com); leading dots are optional.
|
||||
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
|
|
@ -175,6 +177,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
|||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
OCEANBASE_FULLTEXT_PARSER=ik
|
||||
SEEKDB_MEMORY_LIMIT=2G
|
||||
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
|
|
@ -340,15 +353,6 @@ LINDORM_PASSWORD=admin
|
|||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@
|
|||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
||||
|
|
@ -26,6 +26,10 @@
|
|||
cp .env.example .env
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
|
|
@ -80,7 +84,7 @@
|
|||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
for pattern in "Base" "TypeBase"; do
|
||||
printf "%s " "$pattern"
|
||||
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
|
||||
done
|
||||
|
|
@ -77,10 +77,6 @@ class AppExecutionConfig(BaseSettings):
|
|||
description="Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum number of requests per app per day",
|
||||
default=5000,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
|
|
@ -1086,7 +1082,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
|||
)
|
||||
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive credential refresh threshold in seconds",
|
||||
default=180,
|
||||
default=60 * 60,
|
||||
)
|
||||
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive subscription refresh threshold in seconds",
|
||||
|
|
|
|||
|
|
@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
|
|||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
DB_HOST: str = Field(
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
|
|
@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
|
|||
default="",
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="Database URI scheme for SQLAlchemy connection.",
|
||||
default="postgresql",
|
||||
)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
|
|
@ -204,15 +210,15 @@ class DatabaseConfig(BaseSettings):
|
|||
# Parse DB_EXTRAS for 'options'
|
||||
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
|
||||
options = db_extras_dict.get("options", "")
|
||||
# Always include timezone
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
# Merge user options and timezone
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
|
||||
connect_args = {"options": merged_options}
|
||||
connect_args = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
return {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -104,14 +104,11 @@ class BaseApiKeyResource(Resource):
|
|||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id, api_key_id):
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
|
|
@ -25,13 +27,6 @@ class AdvancedPromptTemplateList(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args")
|
||||
.add_argument("model_mode", type=str, required=True, location="args")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
.add_argument("model_name", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
|
|
|||
|
|
@ -8,17 +8,19 @@ from libs.login import login_required
|
|||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
|
|
@ -27,12 +29,6 @@ class AgentLogApi(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
|
|
|||
|
|
@ -251,6 +251,13 @@ class AnnotationExportApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
|
|
@ -259,6 +266,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -268,11 +276,6 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import uuid
|
|||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
from werkzeug.exceptions import BadRequest, abort
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -12,14 +12,16 @@ from controllers.console.wraps import (
|
|||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -106,6 +108,35 @@ class AppListApi(Resource):
|
|||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@api.doc("create_app")
|
||||
|
|
@ -220,10 +251,8 @@ class AppApi(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
# Construct ArgsDict from parsed arguments
|
||||
from services.app_service import AppService as AppServiceType
|
||||
|
||||
args_dict: AppServiceType.ArgsDict = {
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
|
|
@ -353,12 +382,15 @@ class AppExportApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -367,7 +399,6 @@ class AppNameApi(Resource):
|
|||
@marshal_with(app_detail_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
|
@ -455,15 +486,11 @@ class AppApiStatus(Resource):
|
|||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -18,9 +19,23 @@ from services.feature_service import FeatureService
|
|||
|
||||
from .. import console_ns
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -30,18 +45,6 @@ class AppImportApi(Resource):
|
|||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
|
|
|
|||
|
|
@ -3,11 +3,10 @@ from typing import cast
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
|
|
@ -48,15 +47,12 @@ class ModelConfigResource(Resource):
|
|||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -76,17 +81,13 @@ class AppSite(Resource):
|
|||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
|
@ -130,16 +131,12 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@api.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.helper import DatetimeString, convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode, Message
|
||||
from models import AppMode
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
|
|
@ -44,8 +44,9 @@ class DailyMessageStatistic(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
|
|
@ -80,16 +81,19 @@ WHERE
|
|||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
|
|
@ -102,12 +106,18 @@ class DailyConversationStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
try:
|
||||
|
|
@ -115,30 +125,21 @@ class DailyConversationStatistic(Resource):
|
|||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
sa.func.date(
|
||||
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
).label("date"),
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
|
||||
if start_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(stmt, {"tz": account.timezone})
|
||||
for row in rs:
|
||||
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
|
@ -148,11 +149,7 @@ class DailyTerminalsStatistic(Resource):
|
|||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
|
|
@ -165,15 +162,11 @@ class DailyTerminalsStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
|
|
@ -213,11 +206,7 @@ class DailyTokenCostStatistic(Resource):
|
|||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
|
|
@ -230,15 +219,11 @@ class DailyTokenCostStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
|
|
@ -281,11 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
|
|
@ -298,15 +279,11 @@ class AverageSessionInteractionStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
|
|
@ -365,11 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
|
|
@ -382,15 +355,11 @@ class UserSatisfactionRateStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
|
|
@ -439,11 +408,7 @@ class AverageResponseTimeStatistic(Resource):
|
|||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
|
|
@ -456,15 +421,11 @@ class AverageResponseTimeStatistic(Resource):
|
|||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
|
|
@ -504,11 +465,7 @@ class TokensPerSecondStatistic(Resource):
|
|||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
|
|
@ -520,16 +477,11 @@ class TokensPerSecondStatistic(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
|
|
|
|||
|
|
@ -586,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@api.doc("get_published_workflow")
|
||||
|
|
@ -610,6 +617,7 @@ class PublishedWorkflowApi(Resource):
|
|||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@api.expect(parser_publish)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -620,12 +628,8 @@ class PublishedWorkflowApi(Resource):
|
|||
Publish workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
|
|
@ -680,6 +684,9 @@ class DefaultBlockConfigsApi(Resource):
|
|||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@api.doc("get_default_block_config")
|
||||
|
|
@ -687,6 +694,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@api.response(200, "Default block configuration retrieved successfully")
|
||||
@api.response(404, "Block type not found")
|
||||
@api.expect(parser_block)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -696,8 +704,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
"""
|
||||
Get default block config
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
|
|
@ -713,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
|
|||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@api.expect(parser_convert)
|
||||
@api.doc("convert_to_workflow")
|
||||
@api.doc(description="Convert application to workflow mode")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -735,14 +752,7 @@ class ConvertToWorkflowApi(Resource):
|
|||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
|
||||
|
|
@ -756,8 +766,18 @@ class ConvertToWorkflowApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.expect(parser_workflows)
|
||||
@api.doc("get_all_published_workflows")
|
||||
@api.doc(description="Get all published workflows for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
|
|
@ -774,16 +794,9 @@ class PublishedAllWorkflowApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
|
|
@ -970,8 +983,9 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_id", type=str, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
workflow_service = WorkflowService()
|
||||
|
|
@ -1123,8 +1137,9 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"node_ids", type=list, required=True, location="json", nullable=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
workflow_service = WorkflowService()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import logging
|
||||
from typing import NoReturn
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
|
|
@ -21,8 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
|
|||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
|
@ -140,8 +141,11 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
|||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def _api_prerequisite(f):
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
|
@ -155,11 +159,10 @@ def _api_prerequisite(f):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
|
@ -167,6 +170,7 @@ def _api_prerequisite(f):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@api.expect(_create_pagination_parser())
|
||||
@api.doc("get_workflow_variables")
|
||||
@api.doc(description="Get draft workflow variables")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
|
|
|
|||
|
|
@ -30,23 +30,25 @@ def _parse_workflow_run_list_args():
|
|||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -58,28 +60,30 @@ def _parse_workflow_run_count_args():
|
|||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ import logging
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -29,8 +29,7 @@ class WebhookTriggerApi(Resource):
|
|||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = str(args["node_id"])
|
||||
|
|
@ -95,19 +94,19 @@ class AppTriggerEnableApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
|
@ -39,12 +39,10 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -65,12 +63,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import logging
|
|||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
|
@ -42,11 +42,9 @@ class OAuthDataSource(Resource):
|
|||
)
|
||||
@api.response(400, "Invalid provider")
|
||||
@api.response(403, "Admin privileges required")
|
||||
@is_admin_or_owner_required
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
import base64
|
||||
|
||||
from controllers.console import console_ns
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -41,3 +44,37 @@ class Invoices(Resource):
|
|||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
|
||||
class PartnerTenants(Resource):
|
||||
@api.doc("sync_partner_tenants_bindings")
|
||||
@api.doc(description="Sync partner tenants bindings")
|
||||
@api.doc(params={"partner_key": "Partner key"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"SyncPartnerTenantsBindingsRequest",
|
||||
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tenants synced to partner successfully")
|
||||
@api.response(400, "Invalid partner information")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
||||
if not click_id or not decoded_partner_key or not current_user.id:
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from controllers.console.wraps import (
|
|||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
|
|
@ -753,13 +754,11 @@ class DatasetApiKeyApi(Resource):
|
|||
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
|
|
@ -794,15 +793,11 @@ class DatasetApiDeleteApi(Resource):
|
|||
@api.response(204, "API key deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.where(
|
||||
|
|
|
|||
|
|
@ -162,6 +162,7 @@ class DatasetDocumentListApi(Resource):
|
|||
"keyword": "Search keyword",
|
||||
"sort": "Sort order (default: -created_at)",
|
||||
"fetch": "Fetch full details (default: false)",
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Documents retrieved successfully")
|
||||
|
|
@ -175,6 +176,7 @@ class DatasetDocumentListApi(Resource):
|
|||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
|
|
@ -203,6 +205,9 @@ class DatasetDocumentListApi(Resource):
|
|||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
query = query.where(Document.name.like(search))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
|
|
@ -200,12 +200,10 @@ class ExternalDatasetCreateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -121,8 +121,16 @@ class DatasourceOAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@api.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -130,14 +138,7 @@ class DatasourceAuth(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
|
|
@ -168,8 +169,14 @@ class DatasourceAuth(Resource):
|
|||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@api.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -181,10 +188,7 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_delete.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
|
|
@ -195,8 +199,17 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@api.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -205,13 +218,7 @@ class DatasourceAuthUpdateApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
|
|
@ -251,8 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@api.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -260,12 +275,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource_custom.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
|
|
@ -291,8 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -300,8 +314,7 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_default.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
|
|
@ -312,8 +325,16 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@api.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -321,12 +342,7 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_update_name.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
|
|
@ -12,17 +12,21 @@ from models import Account
|
|||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
inputs: dict
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@api.expect(parser)
|
||||
@api.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -34,15 +38,10 @@ class DataSourceContentPreviewApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
args = Parser.model_validate(api.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
datasource_type = args.datasource_type
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
|
|
@ -51,6 +50,6 @@ class DataSourceContentPreviewApi(Resource):
|
|||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return preview_content, 200
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -21,12 +21,11 @@ class RagPipelineImportApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
|
|
@ -71,12 +70,10 @@ class RagPipelineImportConfirmApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# Check user role first
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -98,12 +95,9 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
|||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
|
@ -117,12 +111,9 @@ class RagPipelineExportApi(Resource):
|
|||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
|
|
@ -148,8 +148,12 @@ class DraftRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -162,8 +166,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_iteration(
|
||||
|
|
@ -184,9 +187,11 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
|
|
@ -194,11 +199,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_loop(
|
||||
|
|
@ -217,11 +219,22 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
parser_draft_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_draft_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
|
|
@ -229,17 +242,8 @@ class DraftRagPipelineRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_draft_run.parse_args()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
|
|
@ -255,11 +259,25 @@ class DraftRagPipelineRunApi(Resource):
|
|||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
parser_published_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_published_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
|
|
@ -267,20 +285,8 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_published_run.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
|
||||
|
|
@ -381,11 +387,21 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
#
|
||||
# return result
|
||||
#
|
||||
parser_rag_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
|
|
@ -393,16 +409,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
|
|
@ -429,8 +437,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
|
|
@ -439,16 +449,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
|
|
@ -473,10 +475,17 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_run_api = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@api.expect(parser_run_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
|
|
@ -486,13 +495,8 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_run_api.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
|
|
@ -513,6 +517,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||
class RagPipelineTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, task_id: str):
|
||||
|
|
@ -521,8 +526,6 @@ class RagPipelineTaskStopApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
|
|
@ -534,6 +537,7 @@ class PublishedRagPipelineApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
|
|
@ -541,9 +545,6 @@ class PublishedRagPipelineApi(Resource):
|
|||
Get published pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
if not pipeline.is_published:
|
||||
return None
|
||||
# fetch published workflow by pipeline
|
||||
|
|
@ -556,6 +557,7 @@ class PublishedRagPipelineApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
|
|
@ -563,9 +565,6 @@ class PublishedRagPipelineApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
pipeline = session.merge(pipeline)
|
||||
|
|
@ -592,38 +591,33 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, block_type: str):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_default.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
|
|
@ -639,11 +633,22 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_wf = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@api.expect(parser_wf)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_pagination_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
|
|
@ -651,19 +656,10 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
args = parser_wf.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
|
|
@ -691,11 +687,20 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_wf_id = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@api.expect(parser_wf_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_fields)
|
||||
def patch(self, pipeline: Pipeline, workflow_id: str):
|
||||
|
|
@ -704,22 +709,14 @@ class RagPipelineByIdApi(Resource):
|
|||
"""
|
||||
# Check permission
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_id.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
|
|
@ -752,8 +749,12 @@ class RagPipelineByIdApi(Resource):
|
|||
return workflow
|
||||
|
||||
|
||||
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -763,8 +764,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
|
|
@ -777,6 +777,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -786,8 +787,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
|
|
@ -800,6 +800,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -809,8 +810,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
|
|
@ -823,6 +823,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -832,8 +833,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
|
|
@ -845,8 +845,16 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_wf_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@api.expect(parser_wf_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -856,12 +864,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_run.parse_args()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
||||
|
|
@ -961,8 +964,18 @@ class RagPipelineTransformApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
parser_var = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@api.expect(parser_var)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -974,14 +987,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_var.parse_args()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -35,15 +35,18 @@ recommended_app_list_fields = {
|
|||
}
|
||||
|
||||
|
||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@api.expect(parser_apps)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_fields)
|
||||
def get(self):
|
||||
# language args
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_apps.parse_args()
|
||||
|
||||
language = args.get("language")
|
||||
if language and language in languages:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from controllers.common.errors import (
|
|||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import api
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -36,12 +37,15 @@ class RemoteFileInfoApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@api.expect(parser_upload)
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
args = parser_upload.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class SetupApi(Resource):
|
|||
"email": fields.String(required=True, description="Admin email address"),
|
||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
||||
"password": fields.String(required=True, description="Admin password"),
|
||||
"language": fields.String(required=False, description="Admin language"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from flask import request
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import Tag
|
||||
|
|
@ -16,6 +16,19 @@ def _validate_name(name):
|
|||
return name
|
||||
|
||||
|
||||
parser_tags = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -30,6 +43,7 @@ class TagListApi(Resource):
|
|||
|
||||
return tags, 200
|
||||
|
||||
@api.expect(parser_tags)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -39,20 +53,7 @@ class TagListApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tags.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
|
@ -60,8 +61,14 @@ class TagListApi(Resource):
|
|||
return response, 200
|
||||
|
||||
|
||||
parser_tag_id = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@api.expect(parser_tag_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -72,10 +79,7 @@ class TagUpdateDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tag_id.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
|
@ -87,20 +91,26 @@ class TagUpdateDeleteApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -110,26 +120,23 @@ class TagBindingCreateApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
||||
)
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_create.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_remove = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@api.expect(parser_remove)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -139,15 +146,7 @@ class TagBindingDeleteApi(Resource):
|
|||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
.add_argument(
|
||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_remove.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from . import api, console_ns
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@api.doc("check_version_update")
|
||||
@api.doc(description="Check for application version updates")
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
)
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -37,7 +37,6 @@ class VersionApi(Resource):
|
|||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailChangeLimitError,
|
||||
|
|
@ -43,8 +43,19 @@ from services.billing_service import BillingService
|
|||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
def _init_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
parser.add_argument("invitation_code", type=str, location="json")
|
||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
|
||||
"timezone", type=timezone, required=True, location="json"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
class AccountInitApi(Resource):
|
||||
@api.expect(_init_parser())
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
|
|
@ -53,14 +64,7 @@ class AccountInitApi(Resource):
|
|||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
parser.add_argument("invitation_code", type=str, location="json")
|
||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
|
||||
"timezone", type=timezone, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = _init_parser().parse_args()
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
if not args["invitation_code"]:
|
||||
|
|
@ -106,16 +110,19 @@ class AccountProfileApi(Resource):
|
|||
return current_user
|
||||
|
||||
|
||||
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
class AccountNameApi(Resource):
|
||||
@api.expect(parser_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_name.parse_args()
|
||||
|
||||
# Validate account name length
|
||||
if len(args["name"]) < 3 or len(args["name"]) > 30:
|
||||
|
|
@ -126,68 +133,80 @@ class AccountNameApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@api.expect(parser_avatar)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_avatar.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_interface = reqparse.RequestParser().add_argument(
|
||||
"interface_language", type=supported_language, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
class AccountInterfaceLanguageApi(Resource):
|
||||
@api.expect(parser_interface)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"interface_language", type=supported_language, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_interface.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_theme = reqparse.RequestParser().add_argument(
|
||||
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
class AccountInterfaceThemeApi(Resource):
|
||||
@api.expect(parser_theme)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_theme.parse_args()
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
||||
|
||||
return updated_account
|
||||
|
||||
|
||||
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
class AccountTimezoneApi(Resource):
|
||||
@api.expect(parser_timezone)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_timezone.parse_args()
|
||||
|
||||
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
||||
if args["timezone"] not in pytz.all_timezones:
|
||||
|
|
@ -198,21 +217,24 @@ class AccountTimezoneApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
parser_pw = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("password", type=str, required=False, location="json")
|
||||
.add_argument("new_password", type=str, required=True, location="json")
|
||||
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
class AccountPasswordApi(Resource):
|
||||
@api.expect(parser_pw)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("password", type=str, required=False, location="json")
|
||||
.add_argument("new_password", type=str, required=True, location="json")
|
||||
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pw.parse_args()
|
||||
|
||||
if args["new_password"] != args["repeat_new_password"]:
|
||||
raise RepeatPasswordNotMatchError()
|
||||
|
|
@ -294,20 +316,23 @@ class AccountDeleteVerifyApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_delete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_delete.parse_args()
|
||||
|
||||
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
||||
raise InvalidAccountDeletionCodeError()
|
||||
|
|
@ -317,16 +342,19 @@ class AccountDeleteApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_feedback = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("feedback", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@api.expect(parser_feedback)
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("feedback", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_feedback.parse_args()
|
||||
|
||||
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
||||
|
||||
|
|
@ -351,6 +379,14 @@ class EducationVerifyApi(Resource):
|
|||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||
|
||||
|
||||
parser_edu = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("institution", type=str, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education")
|
||||
class EducationApi(Resource):
|
||||
status_fields = {
|
||||
|
|
@ -360,6 +396,7 @@ class EducationApi(Resource):
|
|||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_edu)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -368,13 +405,7 @@ class EducationApi(Resource):
|
|||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("institution", type=str, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_edu.parse_args()
|
||||
|
||||
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
|
||||
|
||||
|
|
@ -394,6 +425,14 @@ class EducationApi(Resource):
|
|||
return res
|
||||
|
||||
|
||||
parser_autocomplete = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keywords", type=str, required=True, location="args")
|
||||
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/education/autocomplete")
|
||||
class EducationAutoCompleteApi(Resource):
|
||||
data_fields = {
|
||||
|
|
@ -402,6 +441,7 @@ class EducationAutoCompleteApi(Resource):
|
|||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_autocomplete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -409,33 +449,30 @@ class EducationAutoCompleteApi(Resource):
|
|||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keywords", type=str, required=True, location="args")
|
||||
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_autocomplete.parse_args()
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
||||
|
||||
|
||||
parser_change_email = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
.add_argument("phase", type=str, required=False, location="json")
|
||||
.add_argument("token", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@api.expect(parser_change_email)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
.add_argument("phase", type=str, required=False, location="json")
|
||||
.add_argument("token", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_change_email.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -470,20 +507,23 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_validity = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@api.expect(parser_validity)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validity.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
|
|
@ -514,20 +554,23 @@ class ChangeEmailCheckApi(Resource):
|
|||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_reset = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("new_email", type=email, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
class ChangeEmailResetApi(Resource):
|
||||
@api.expect(parser_reset)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("new_email", type=email, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_reset.parse_args()
|
||||
|
||||
if AccountService.is_account_in_freeze(args["new_email"]):
|
||||
raise AccountInFreezeError()
|
||||
|
|
@ -555,12 +598,15 @@ class ChangeEmailResetApi(Resource):
|
|||
return updated_account
|
||||
|
||||
|
||||
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@api.expect(parser_check)
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_check.parse_args()
|
||||
if AccountService.is_account_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args["email"]):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -31,11 +30,10 @@ class EndpointCreateApi(Resource):
|
|||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
|
|
@ -168,6 +166,7 @@ class EndpointDeleteApi(Resource):
|
|||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
|
@ -175,9 +174,6 @@ class EndpointDeleteApi(Resource):
|
|||
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
|
|
@ -207,6 +203,7 @@ class EndpointUpdateApi(Resource):
|
|||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
|
@ -223,9 +220,6 @@ class EndpointUpdateApi(Resource):
|
|||
settings = args["settings"]
|
||||
name = args["name"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -252,6 +246,7 @@ class EndpointEnableApi(Resource):
|
|||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
|
@ -261,9 +256,6 @@ class EndpointEnableApi(Resource):
|
|||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
@ -284,6 +276,7 @@ class EndpointDisableApi(Resource):
|
|||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
|
@ -293,9 +286,6 @@ class EndpointDisableApi(Resource):
|
|||
|
||||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
|||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
|
|
@ -48,22 +48,25 @@ class MemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_invite = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("emails", type=list, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/invite-email")
|
||||
class MemberInviteEmailApi(Resource):
|
||||
"""Invite a new member by email."""
|
||||
|
||||
@api.expect(parser_invite)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("emails", type=list, required=True, location="json")
|
||||
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_invite.parse_args()
|
||||
|
||||
invitee_emails = args["emails"]
|
||||
invitee_role = args["role"]
|
||||
|
|
@ -143,16 +146,19 @@ class MemberCancelInviteApi(Resource):
|
|||
}, 200
|
||||
|
||||
|
||||
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
|
||||
class MemberUpdateRoleApi(Resource):
|
||||
"""Update member role."""
|
||||
|
||||
@api.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id):
|
||||
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_update.parse_args()
|
||||
new_role = args["role"]
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
|
|
@ -191,17 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
return {"result": "success", "accounts": members}, 200
|
||||
|
||||
|
||||
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
|
||||
class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
@api.expect(parser_send)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_send.parse_args()
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
|
@ -229,19 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
parser_owner = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@api.expect(parser_owner)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_owner.parse_args()
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
|
|
@ -276,17 +288,20 @@ class OwnerTransferCheckApi(Resource):
|
|||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
parser_owner_transfer = reqparse.RequestParser().add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
class OwnerTransfer(Resource):
|
||||
@api.expect(parser_owner_transfer)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_owner_transfer.parse_args()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
|
|||
|
|
@ -2,10 +2,9 @@ import io
|
|||
|
||||
from flask import send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -14,9 +13,19 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
parser_model = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@api.expect(parser_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -24,15 +33,7 @@ class ModelProviderListApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||
|
|
@ -40,8 +41,30 @@ class ModelProviderListApi(Resource):
|
|||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
parser_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
parser_delete_cred = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@api.expect(parser_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -49,10 +72,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
|
|
@ -61,20 +81,14 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_post_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -90,21 +104,15 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_put_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -121,17 +129,14 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
|
|
@ -141,19 +146,21 @@ class ModelProviderCredentialApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
|
|
@ -164,17 +171,20 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validate.parse_args()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
|
|
@ -218,27 +228,29 @@ class ModelProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon), mimetype=mimetype)
|
||||
|
||||
|
||||
parser_preferred = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@api.expect(parser_preferred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"preferred_provider_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=["system", "custom"],
|
||||
location="json",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_preferred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.switch_preferred_provider(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import logging
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -16,23 +15,29 @@ from services.model_provider_service import ModelProviderService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
parser_get_default = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
parser_post_default = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@api.expect(parser_get_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_get_default.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
|
|
@ -41,19 +46,15 @@ class DefaultModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@api.expect(parser_post_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_default.parse_args()
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
for model_setting in model_settings:
|
||||
|
|
@ -84,6 +85,35 @@ class DefaultModelApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_post_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_models = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
|
||||
class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -97,32 +127,15 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
@api.expect(parser_post_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = parser_post_models.parse_args()
|
||||
|
||||
if args.get("config_from", "") == "custom-model":
|
||||
if not args.get("credential_id"):
|
||||
|
|
@ -160,28 +173,15 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@api.expect(parser_delete_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_delete_models.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model(
|
||||
|
|
@ -191,29 +191,76 @@ class ModelProviderModelApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_get_credentials = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
parser_post_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser_put_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
parser_delete_cred = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@api.expect(parser_get_credentials)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="args",
|
||||
)
|
||||
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_get_credentials.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
current_credential = model_provider_service.get_model_credential(
|
||||
|
|
@ -257,30 +304,15 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
}
|
||||
)
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_post_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -304,31 +336,14 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_put_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -347,28 +362,14 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = parser_delete_cred.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
|
|
@ -382,30 +383,32 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
parser_switch = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
|
|
@ -418,29 +421,32 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_model_enable_disable = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.enable_model(
|
||||
|
|
@ -454,25 +460,14 @@ class ModelProviderModelEnableApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_model_enable_disable.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.disable_model(
|
||||
|
|
@ -482,28 +477,31 @@ class ModelProviderModelDisableApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_validate = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_validate.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -530,16 +528,19 @@ class ModelProviderModelValidateApi(Resource):
|
|||
return response
|
||||
|
||||
|
||||
parser_parameter = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@api.expect(parser_parameter)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"model", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_parameter.parse_args()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from flask_restx import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -37,19 +37,22 @@ class PluginDebuggingKeyApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_list = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@api.expect(parser_list)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_list.parse_args()
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
|
|
@ -58,14 +61,17 @@ class PluginListApi(Resource):
|
|||
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
|
||||
|
||||
|
||||
parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@api.expect(parser_latest)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_latest.parse_args()
|
||||
|
||||
try:
|
||||
versions = PluginService.list_latest_versions(args["plugin_ids"])
|
||||
|
|
@ -75,16 +81,19 @@ class PluginListLatestVersionsApi(Resource):
|
|||
return jsonable_encoder({"versions": versions})
|
||||
|
||||
|
||||
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@api.expect(parser_ids)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_ids.parse_args()
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
|
|
@ -94,16 +103,19 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
parser_icon = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@api.expect(parser_icon)
|
||||
@setup_required
|
||||
def get(self):
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
.add_argument("filename", type=str, required=True, location="args")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_icon.parse_args()
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
|
|
@ -120,9 +132,11 @@ class PluginAssetApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
req.add_argument("file_name", type=str, required=True, location="args")
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
.add_argument("file_name", type=str, required=True, location="args")
|
||||
)
|
||||
args = req.parse_args()
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
|
@ -157,8 +171,17 @@ class PluginUploadFromPkgApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_github = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/github")
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@api.expect(parser_github)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -166,13 +189,7 @@ class PluginUploadFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_github.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
|
|
@ -206,19 +223,21 @@ class PluginUploadFromBundleApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkg = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/pkg")
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@api.expect(parser_pkg)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pkg.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
|
|
@ -233,8 +252,18 @@ class PluginInstallFromPkgApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_githubapi = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/github")
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@api.expect(parser_githubapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -242,14 +271,7 @@ class PluginInstallFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_githubapi.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
|
|
@ -265,8 +287,14 @@ class PluginInstallFromGithubApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_marketplace = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/install/marketplace")
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -274,10 +302,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_marketplace.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
|
|
@ -292,19 +317,21 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
parser_pkgapi = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
|
||||
class PluginFetchMarketplacePkgApi(Resource):
|
||||
@api.expect(parser_pkgapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_pkgapi.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -319,8 +346,14 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_fetch = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@api.expect(parser_fetch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -328,10 +361,7 @@ class PluginFetchManifestApi(Resource):
|
|||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_fetch.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -345,8 +375,16 @@ class PluginFetchManifestApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_tasks = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks")
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@api.expect(parser_tasks)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -354,12 +392,7 @@ class PluginFetchInstallTasksApi(Resource):
|
|||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, required=True, location="args")
|
||||
.add_argument("page_size", type=int, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_tasks.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -429,8 +462,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_marketplace_api = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -438,12 +479,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_marketplace_api.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -455,8 +491,19 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_github_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/github")
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@api.expect(parser_github_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -464,15 +511,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
.add_argument("repo", type=str, required=True, location="json")
|
||||
.add_argument("version", type=str, required=True, location="json")
|
||||
.add_argument("package", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_github_post.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -489,15 +528,20 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_uninstall = reqparse.RequestParser().add_argument(
|
||||
"plugin_installation_id", type=str, required=True, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@api.expect(parser_uninstall)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_uninstall.parse_args()
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -507,8 +551,16 @@ class PluginUninstallApi(Resource):
|
|||
raise ValueError(e)
|
||||
|
||||
|
||||
parser_change_post = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("install_permission", type=str, required=True, location="json")
|
||||
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@api.expect(parser_change_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -518,12 +570,7 @@ class PluginChangePermissionApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("install_permission", type=str, required=True, location="json")
|
||||
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_change_post.parse_args()
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
|
|
@ -558,29 +605,29 @@ class PluginFetchPermissionApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_dynamic = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@api.expect(parser_dynamic)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
# check if the user is admin or owner
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = current_user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||
.add_argument("provider", type=str, required=True, location="args")
|
||||
.add_argument("action", type=str, required=True, location="args")
|
||||
.add_argument("parameter", type=str, required=True, location="args")
|
||||
.add_argument("credential_id", type=str, required=False, location="args")
|
||||
.add_argument("provider_type", type=str, required=True, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_dynamic.parse_args()
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
|
|
@ -599,8 +646,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
parser_change = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("permission", type=dict, required=True, location="json")
|
||||
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@api.expect(parser_change)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -609,12 +664,7 @@ class PluginChangePreferencesApi(Resource):
|
|||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
req = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("permission", type=dict, required=True, location="json")
|
||||
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_change.parse_args()
|
||||
|
||||
permission = args["permission"]
|
||||
|
||||
|
|
@ -694,8 +744,12 @@ class PluginFetchPreferencesApi(Resource):
|
|||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@api.expect(parser_exclude)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -703,8 +757,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
args = parser_exclude.parse_args()
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||
|
||||
|
|
@ -716,9 +769,11 @@ class PluginReadmeApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
parser.add_argument("language", type=str, required=False, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
.add_argument("language", type=str, required=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -10,10 +10,11 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
|
|
@ -52,8 +53,19 @@ def is_valid_url(url: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
parser_tool = reqparse.RequestParser().add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-providers")
|
||||
class ToolProviderListApi(Resource):
|
||||
@api.expect(parser_tool)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -62,15 +74,7 @@ class ToolProviderListApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
req = reqparse.RequestParser().add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_tool.parse_args()
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None))
|
||||
|
||||
|
|
@ -102,20 +106,22 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
||||
parser_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
req = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = req.parse_args()
|
||||
args = parser_delete.parse_args()
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
tenant_id,
|
||||
|
|
@ -124,8 +130,17 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_add = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
class ToolBuiltinProviderAddApi(Resource):
|
||||
@api.expect(parser_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -134,13 +149,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_add.parse_args()
|
||||
|
||||
if args["type"] not in CredentialType.values():
|
||||
raise ValueError(f"Invalid credential type: {args['type']}")
|
||||
|
|
@ -155,27 +164,26 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@api.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_update.parse_args()
|
||||
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
|
|
@ -213,32 +221,32 @@ class ToolBuiltinProviderIconApi(Resource):
|
|||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
parser_api_add = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/add")
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@api.expect(parser_api_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_add.parse_args()
|
||||
|
||||
return ApiToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
|
|
@ -254,8 +262,12 @@ class ToolApiProviderAddApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/remote")
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@api.expect(parser_remote)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -264,9 +276,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_remote.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
user_id,
|
||||
|
|
@ -275,8 +285,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_tools = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/tools")
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@api.expect(parser_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -285,11 +301,7 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_tools.parse_args()
|
||||
|
||||
return jsonable_encoder(
|
||||
ApiToolManageService.list_api_tool_provider_tools(
|
||||
|
|
@ -300,33 +312,33 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_api_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/update")
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@api.expect(parser_api_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_update.parse_args()
|
||||
|
||||
return ApiToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
|
|
@ -343,24 +355,24 @@ class ToolApiProviderUpdateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_api_delete = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/delete")
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@api.expect(parser_api_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_api_delete.parse_args()
|
||||
|
||||
return ApiToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
|
|
@ -369,8 +381,12 @@ class ToolApiProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/get")
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@api.expect(parser_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -379,11 +395,7 @@ class ToolApiProviderGetApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider", type=str, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_get.parse_args()
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
|
|
@ -407,40 +419,44 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_schema = reqparse.RequestParser().add_argument(
|
||||
"schema", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/schema")
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@api.expect(parser_schema)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"schema", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_schema.parse_args()
|
||||
|
||||
return ApiToolManageService.parser_api_schema(
|
||||
schema=args["schema"],
|
||||
)
|
||||
|
||||
|
||||
parser_pre = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@api.expect(parser_pre)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_pre.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_tenant_id,
|
||||
|
|
@ -453,32 +469,32 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_create = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
|
||||
class ToolWorkflowProviderCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_create.parse_args()
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=user_id,
|
||||
|
|
@ -494,32 +510,31 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_workflow_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
|
||||
class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@api.expect(parser_workflow_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_workflow_update.parse_args()
|
||||
|
||||
if not args["workflow_tool_id"]:
|
||||
raise ValueError("incorrect workflow_tool_id")
|
||||
|
|
@ -538,24 +553,24 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_workflow_delete = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
|
||||
class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@api.expect(parser_workflow_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
reqparser = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
args = reqparser.parse_args()
|
||||
args = parser_workflow_delete.parse_args()
|
||||
|
||||
return WorkflowToolManageService.delete_workflow_tool(
|
||||
user_id,
|
||||
|
|
@ -564,8 +579,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_wf_get = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
|
||||
class ToolWorkflowProviderGetApi(Resource):
|
||||
@api.expect(parser_wf_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -574,13 +597,7 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_get.parse_args()
|
||||
|
||||
if args.get("workflow_tool_id"):
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||
|
|
@ -600,8 +617,14 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
return jsonable_encoder(tool)
|
||||
|
||||
|
||||
parser_wf_tools = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
|
||||
class ToolWorkflowProviderListToolApi(Resource):
|
||||
@api.expect(parser_wf_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -610,11 +633,7 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||
|
||||
user_id = user.id
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parser_wf_tools.parse_args()
|
||||
|
||||
return jsonable_encoder(
|
||||
WorkflowToolManageService.list_single_workflow_tools(
|
||||
|
|
@ -699,18 +718,15 @@ class ToolLabelsApi(Resource):
|
|||
class ToolPluginOAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
|
||||
# todo check permission
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
|
@ -790,37 +806,43 @@ class ToolOAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_default_cred = reqparse.RequestParser().add_argument(
|
||||
"id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@api.expect(parser_default_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_default_cred.parse_args()
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
)
|
||||
|
||||
|
||||
parser_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@api.expect(parser_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
def post(self, provider: str):
|
||||
args = parser_custom.parse_args()
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -878,25 +900,44 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
parser_mcp = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
parser_mcp_put = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
parser_mcp_delete = reqparse.RequestParser().add_argument(
|
||||
"provider_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp")
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@api.expect(parser_mcp)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Parse and validate models
|
||||
|
|
@ -921,24 +962,12 @@ class ToolProviderMCPApi(Resource):
|
|||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@api.expect(parser_mcp_put)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp_put.parse_args()
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
|
@ -972,14 +1001,12 @@ class ToolProviderMCPApi(Resource):
|
|||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_mcp_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"provider_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_mcp_delete.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
|
|
@ -988,18 +1015,21 @@ class ToolProviderMCPApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
parser_auth = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@api.expect(parser_auth)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_auth.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -1035,7 +1065,13 @@ class ToolMCPAuthApi(Resource):
|
|||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
try:
|
||||
auth_result = auth(provider_entity, args.get("authorization_code"))
|
||||
# Pass the extracted OAuth metadata hints to auth()
|
||||
auth_result = auth(
|
||||
provider_entity,
|
||||
args.get("authorization_code"),
|
||||
resource_metadata_url=e.resource_metadata_url,
|
||||
scope_hint=e.scope_hint,
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
|
|
@ -1045,7 +1081,7 @@ class ToolMCPAuthApi(Resource):
|
|||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except MCPError as e:
|
||||
except (MCPError, ValueError) as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
|
@ -1097,15 +1133,18 @@ class ToolMCPUpdateApi(Resource):
|
|||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
parser_cb = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
@api.expect(parser_cb)
|
||||
def get(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parser_cb.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest, Forbidden
|
|||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -67,14 +67,12 @@ class TriggerProviderInfoApi(Resource):
|
|||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -92,17 +90,16 @@ class TriggerSubscriptionListApi(Resource):
|
|||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"credential_type", type=str, required=False, nullable=True, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -133,18 +130,17 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
|||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -173,15 +169,17 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
|
|
@ -223,24 +221,23 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
|||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
|
|
@ -264,14 +261,12 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
class TriggerSubscriptionDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id: str):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -446,14 +441,12 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
|
@ -493,18 +486,18 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
|
@ -524,14 +517,12 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
|||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -128,7 +128,7 @@ class TenantApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tenant_fields)
|
||||
def get(self):
|
||||
def post(self):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
|
|
@ -150,15 +150,18 @@ class TenantApi(Resource):
|
|||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
class SwitchWorkspaceApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_switch.parse_args()
|
||||
|
||||
# check if tenant_id is valid, 403 if not
|
||||
try:
|
||||
|
|
@ -242,16 +245,19 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
return {"id": upload_file.id}, 201
|
||||
|
||||
|
||||
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/info")
|
||||
class WorkspaceInfoApi(Resource):
|
||||
@api.expect(parser_info)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_info.parse_args()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
|
|
|||
|
|
@ -315,3 +315,19 @@ def edit_permission_required(f: Callable[P, R]):
|
|||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object()
|
||||
if not isinstance(user, Account) or not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
|
|
|||
|
|
@ -3,14 +3,12 @@ from typing import Literal
|
|||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.api import HTTPStatus
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
|
@ -161,14 +159,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id):
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
args = annotation_create_parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
|
|
@ -185,13 +179,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from flask_restx import marshal, reqparse
|
|||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
|
|
@ -619,11 +620,9 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import json
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -31,7 +34,7 @@ from fields.document_fields import document_fields, document_status_fields
|
|||
from libs.login import current_user
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
||||
# Define parsers for document operations
|
||||
|
|
@ -51,15 +54,26 @@ document_text_create_parser = (
|
|||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
document_text_update_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("text", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class DocumentTextUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
text: str | None = None
|
||||
process_rule: ProcessRule | None = None
|
||||
doc_form: str = "text_model"
|
||||
doc_language: str = "English"
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_text_and_name(self) -> Self:
|
||||
if self.text is not None and self.name is None:
|
||||
raise ValueError("name is required when text is provided")
|
||||
return self
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
|
||||
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
|
@ -160,7 +174,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@service_api_ns.expect(document_text_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@service_api_ns.doc(description="Update an existing document by providing text content")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
|
@ -173,12 +187,10 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
args = document_text_update_parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
|
@ -198,11 +210,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
|||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args["text"]:
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both text and name must be strings.")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
|
|
@ -456,12 +466,16 @@ class DocumentListApi(DatasetApiResource):
|
|||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
query = query.where(Document.name.like(search))
|
||||
|
|
|
|||
|
|
@ -88,12 +88,6 @@ class AudioApi(WebApiResource):
|
|||
|
||||
@web_ns.route("/text-to-audio")
|
||||
class TextApi(WebApiResource):
|
||||
text_to_audio_response_fields = {
|
||||
"audio_url": fields.String,
|
||||
"duration": fields.Float,
|
||||
}
|
||||
|
||||
@marshal_with(text_to_audio_response_fields)
|
||||
@web_ns.doc("Text to Audio")
|
||||
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@web_ns.doc(
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ class LoginStatusApi(Resource):
|
|||
)
|
||||
def get(self):
|
||||
app_code = request.args.get("app_code")
|
||||
user_id = request.args.get("user_id")
|
||||
token = extract_webapp_access_token(request)
|
||||
if not app_code:
|
||||
return {
|
||||
|
|
@ -103,7 +104,7 @@ class LoginStatusApi(Resource):
|
|||
user_logged_in = False
|
||||
|
||||
try:
|
||||
_ = decode_jwt_token(app_code=app_code)
|
||||
_ = decode_jwt_token(app_code=app_code, user_id=user_id)
|
||||
app_logged_in = True
|
||||
except Exception:
|
||||
app_logged_in = False
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
|
|||
return decorator
|
||||
|
||||
|
||||
def decode_jwt_token(app_code: str | None = None):
|
||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
||||
system_features = FeatureService.get_system_features()
|
||||
if not app_code:
|
||||
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
||||
|
|
@ -63,6 +63,10 @@ def decode_jwt_token(app_code: str | None = None):
|
|||
if not end_user:
|
||||
raise NotFound()
|
||||
|
||||
# Validate user_id against end_user's session_id if provided
|
||||
if user_id is not None and end_user.session_id != user_id:
|
||||
raise Unauthorized("Authentication has expired.")
|
||||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
webapp_settings = None
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
datasource_type=datasource_type,
|
||||
datasource_info=json.dumps(datasource_info),
|
||||
datasource_node_id=start_node_id,
|
||||
input_data=inputs,
|
||||
input_data=dict(inputs),
|
||||
pipeline_id=pipeline.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -145,7 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# for trigger debug run, not prepare user inputs
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
inputs = self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
|
|
|
|||
|
|
@ -644,14 +644,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||
if not workflow_run_id:
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
|
||||
workflow_app_log.workflow_id = self._workflow.id
|
||||
workflow_app_log.workflow_run_id = workflow_run_id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = self._created_by_role
|
||||
workflow_app_log.created_by = self._user_id
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
created_from=created_from.value,
|
||||
created_by_role=self._created_by_role,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
|
||||
session.add(workflow_app_log)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import InvokeFrom locally to avoid circular import
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
|
||||
class DatasourceRuntime(BaseModel):
|
||||
"""
|
||||
|
|
@ -17,7 +13,7 @@ class DatasourceRuntime(BaseModel):
|
|||
|
||||
tenant_id: str
|
||||
datasource_id: str | None = None
|
||||
invoke_from: Optional["InvokeFrom"] = None
|
||||
invoke_from: InvokeFrom | None = None
|
||||
datasource_invoke_from: DatasourceInvokeFrom | None = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import secrets
|
|||
import urllib.parse
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import ConnectError, HTTPStatusError, RequestError
|
||||
import httpx
|
||||
from httpx import RequestError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
|
||||
|
|
@ -20,6 +21,7 @@ from core.mcp.types import (
|
|||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
|
@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
|
|||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url: str | None, server_url: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
"""
|
||||
urls = []
|
||||
|
||||
# First priority: URL from WWW-Authenticate header
|
||||
if www_auth_resource_metadata_url:
|
||||
urls.append(www_auth_resource_metadata_url)
|
||||
|
||||
# Fallback: construct from server URL
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
if fallback_url not in urls:
|
||||
urls.append(fallback_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.
|
||||
|
||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||
|
||||
Per RFC 8414 section 3:
|
||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
||||
|
||||
Example:
|
||||
- issuer: https://example.com/oauth
|
||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
||||
"""
|
||||
urls = []
|
||||
base_url = auth_server_url or server_url
|
||||
|
||||
parsed = urlparse(base_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
||||
|
||||
# Try OpenID Connect discovery first (more common)
|
||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
||||
|
||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
||||
# Include the path component if present in the issuer URL
|
||||
if path:
|
||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
||||
else:
|
||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def discover_protected_resource_metadata(
|
||||
prm_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
|
||||
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return ProtectedResourceMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def discover_oauth_authorization_server_metadata(
|
||||
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
|
||||
) -> OAuthMetadata | None:
|
||||
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_effective_scope(
|
||||
scope_from_www_auth: str | None,
|
||||
prm: ProtectedResourceMetadata | None,
|
||||
asm: OAuthMetadata | None,
|
||||
client_scope: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Determine effective scope using priority-based selection strategy.
|
||||
|
||||
Priority order:
|
||||
1. WWW-Authenticate header scope (server explicit requirement)
|
||||
2. Protected Resource Metadata scopes
|
||||
3. OAuth Authorization Server Metadata scopes
|
||||
4. Client configured scope
|
||||
"""
|
||||
if scope_from_www_auth:
|
||||
return scope_from_www_auth
|
||||
|
||||
if prm and prm.scopes_supported:
|
||||
return " ".join(prm.scopes_supported)
|
||||
|
||||
if asm and asm.scopes_supported:
|
||||
return " ".join(asm.scopes_supported)
|
||||
|
||||
return client_scope
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
|
|
@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
|||
return False, ""
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
# The oauth_discovery_url is the authorization server base URL
|
||||
# Try OpenID Connect discovery first (more common), then OAuth 2.0
|
||||
urls_to_try = [
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
|
||||
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
|
||||
]
|
||||
else:
|
||||
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
|
||||
def discover_oauth_metadata(
|
||||
server_url: str,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
protocol_version: str | None = None,
|
||||
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
|
||||
"""
|
||||
Discover OAuth metadata using RFC 8414/9470 standards.
|
||||
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
Args:
|
||||
server_url: The MCP server URL
|
||||
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
|
||||
scope_hint: Scope hint from WWW-Authenticate header
|
||||
protocol_version: MCP protocol version
|
||||
|
||||
for url in urls_to_try:
|
||||
try:
|
||||
response = ssrf_proxy.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
continue
|
||||
if not response.is_success:
|
||||
response.raise_for_status()
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except (RequestError, HTTPStatusError) as e:
|
||||
if isinstance(e, ConnectError):
|
||||
response = ssrf_proxy.get(url)
|
||||
if response.status_code == 404:
|
||||
continue # Try next URL
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
# For other errors, try next URL
|
||||
continue
|
||||
Returns:
|
||||
(oauth_metadata, protected_resource_metadata, scope_hint)
|
||||
"""
|
||||
# Discover Protected Resource Metadata
|
||||
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)
|
||||
|
||||
return None # No metadata found
|
||||
# Get authorization server URL from PRM or use server URL
|
||||
auth_server_url = None
|
||||
if prm and prm.authorization_servers:
|
||||
auth_server_url = prm.authorization_servers[0]
|
||||
|
||||
# Discover OAuth Authorization Server Metadata
|
||||
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)
|
||||
|
||||
return asm, prm, scope_hint
|
||||
|
||||
|
||||
def start_authorization(
|
||||
|
|
@ -166,6 +287,7 @@ def start_authorization(
|
|||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
scope: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
|
|
@ -175,13 +297,6 @@ def start_authorization(
|
|||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
|
|
@ -210,10 +325,49 @@ def start_authorization(
|
|||
"state": state_key,
|
||||
}
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
params["scope"] = scope
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def _parse_token_response(response: httpx.Response) -> OAuthTokens:
|
||||
"""
|
||||
Parse OAuth token response supporting both JSON and form-urlencoded formats.
|
||||
|
||||
Per RFC 6749 Section 5.1, the standard format is JSON.
|
||||
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
|
||||
application/x-www-form-urlencoded format for backwards compatibility.
|
||||
|
||||
Args:
|
||||
response: The HTTP response from token endpoint
|
||||
|
||||
Returns:
|
||||
Parsed OAuth tokens
|
||||
|
||||
Raises:
|
||||
ValueError: If response cannot be parsed
|
||||
"""
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
|
||||
if "application/json" in content_type:
|
||||
# Standard OAuth 2.0 JSON response (RFC 6749)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
elif "application/x-www-form-urlencoded" in content_type:
|
||||
# Legacy form-urlencoded response (non-standard but used by some providers)
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
else:
|
||||
# No content-type or unknown - try JSON first, fallback to form-urlencoded
|
||||
try:
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
except (ValidationError, json.JSONDecodeError):
|
||||
token_data = dict(urllib.parse.parse_qsl(response.text))
|
||||
return OAuthTokens.model_validate(token_data)
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: OAuthMetadata | None,
|
||||
|
|
@ -246,7 +400,7 @@ def exchange_authorization(
|
|||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
|
|
@ -279,7 +433,7 @@ def refresh_authorization(
|
|||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
|
|
@ -322,7 +476,7 @@ def client_credentials_flow(
|
|||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def register_client(
|
||||
|
|
@ -352,6 +506,8 @@ def auth(
|
|||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
|
@ -363,18 +519,26 @@ def auth(
|
|||
provider: The MCP provider entity
|
||||
authorization_code: Optional authorization code from OAuth callback
|
||||
state_param: Optional state parameter from OAuth callback
|
||||
resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
|
||||
scope_hint: Optional scope hint from WWW-Authenticate header
|
||||
|
||||
Returns:
|
||||
AuthResult containing actions to be performed and response data
|
||||
"""
|
||||
actions: list[AuthAction] = []
|
||||
server_url = provider.decrypt_server_url()
|
||||
server_metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Discover OAuth metadata using RFC 8414/9470 standards
|
||||
server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
|
||||
server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
|
||||
)
|
||||
|
||||
client_metadata = provider.client_metadata
|
||||
provider_id = provider.id
|
||||
tenant_id = provider.tenant_id
|
||||
client_information = provider.retrieve_client_information()
|
||||
redirect_url = provider.redirect_url
|
||||
credentials = provider.decrypt_credentials()
|
||||
|
||||
# Determine grant type based on server metadata
|
||||
if not server_metadata:
|
||||
|
|
@ -392,8 +556,8 @@ def auth(
|
|||
else:
|
||||
effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
|
||||
|
||||
# Get stored credentials
|
||||
credentials = provider.decrypt_credentials()
|
||||
# Determine effective scope using priority-based strategy
|
||||
effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))
|
||||
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
|
|
@ -425,12 +589,11 @@ def auth(
|
|||
if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
|
||||
# Direct token request without user interaction
|
||||
try:
|
||||
scope = credentials.get("scope")
|
||||
tokens = client_credentials_flow(
|
||||
server_url,
|
||||
server_metadata,
|
||||
client_information,
|
||||
scope,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save tokens and grant type
|
||||
|
|
@ -526,6 +689,7 @@ def auth(
|
|||
redirect_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
effective_scope,
|
||||
)
|
||||
|
||||
# Return action to save code verifier
|
||||
|
|
|
|||
|
|
@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
|
|||
mcp_service = MCPToolManageService(session=session)
|
||||
|
||||
# Perform authentication using the service's auth method
|
||||
mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
|
||||
# Extract OAuth metadata hints from the error
|
||||
mcp_service.auth_with_actions(
|
||||
self.provider_entity,
|
||||
self.authorization_code,
|
||||
resource_metadata_url=error.resource_metadata_url,
|
||||
scope_hint=error.scope_hint,
|
||||
)
|
||||
|
||||
# Retrieve new tokens
|
||||
self.provider_entity = mcp_service.get_provider_entity(
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ def sse_client(
|
|||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPAuthError(response=exc.response)
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
|
|
|
|||
|
|
@ -138,6 +138,10 @@ class StreamableHTTPTransport:
|
|||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
# ping event send by server will be recognized as a message event with empty data by httpx-sse's SSEDecoder
|
||||
if not sse.data.strip():
|
||||
return False
|
||||
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug("SSE message: %s", message)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):
|
|||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
message: str | None = None,
|
||||
response: "httpx.Response | None" = None,
|
||||
www_authenticate_header: str | None = None,
|
||||
):
|
||||
"""
|
||||
MCP Authentication Error.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
response: HTTP response object (will extract WWW-Authenticate header if provided)
|
||||
www_authenticate_header: Pre-extracted WWW-Authenticate header value
|
||||
"""
|
||||
super().__init__(message or "Authentication failed")
|
||||
|
||||
# Extract OAuth metadata hints from WWW-Authenticate header
|
||||
if response is not None:
|
||||
www_authenticate_header = response.headers.get("WWW-Authenticate")
|
||||
|
||||
self.resource_metadata_url: str | None = None
|
||||
self.scope_hint: str | None = None
|
||||
|
||||
if www_authenticate_header:
|
||||
self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
|
||||
self.scope_hint = self._extract_field(www_authenticate_header, "scope")
|
||||
|
||||
@staticmethod
|
||||
def _extract_field(www_auth: str, field_name: str) -> str | None:
|
||||
"""Extract a specific field from the WWW-Authenticate header."""
|
||||
# Pattern to match field="value" or field=value
|
||||
pattern = rf'{field_name}="([^"]*)"'
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try without quotes
|
||||
pattern = rf"{field_name}=([^\s,]+)"
|
||||
match = re.search(pattern, www_auth)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MCPRefreshTokenError(MCPError):
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class BaseSession(
|
|||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
|
|
@ -230,7 +230,7 @@ class BaseSession(
|
|||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
|
|
@ -261,11 +261,17 @@ class BaseSession(
|
|||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, HTTPStatusError):
|
||||
# HTTPStatusError from streamable_client with preserved response object
|
||||
if response_or_error.response.status_code == 401:
|
||||
raise MCPAuthError(response=response_or_error.response)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
raise MCPAuthError(message=response_or_error.error.message)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
|
|
@ -327,13 +333,17 @@ class BaseSession(
|
|||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
# For 401 errors, pass the HTTPStatusError directly to preserve response object
|
||||
if message.response.status_code == 401:
|
||||
response_queue.put(message)
|
||||
else:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ for reference.
|
|||
not separate types in the schema.
|
||||
"""
|
||||
# Client support both version, not support 2025-06-18 yet.
|
||||
LATEST_PROTOCOL_VERSION = "2025-03-26"
|
||||
LATEST_PROTOCOL_VERSION = "2025-06-18"
|
||||
# Server support 2024-11-05 to allow claude to use.
|
||||
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
|
||||
|
|
@ -1330,3 +1330,13 @@ class OAuthMetadata(BaseModel):
|
|||
response_types_supported: list[str]
|
||||
grant_types_supported: list[str] | None = None
|
||||
code_challenge_methods_supported: list[str] | None = None
|
||||
scopes_supported: list[str] | None = None
|
||||
|
||||
|
||||
class ProtectedResourceMetadata(BaseModel):
|
||||
"""OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
|
||||
|
||||
resource: str | None = None
|
||||
authorization_servers: list[str]
|
||||
scopes_supported: list[str] | None = None
|
||||
bearer_methods_supported: list[str] | None = None
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class OpenAIModeration(Moderation):
|
|||
text = "\n".join(str(inputs.values()))
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable"
|
||||
tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest"
|
||||
)
|
||||
|
||||
openai_moderation = model_instance.invoke_moderation(text=text)
|
||||
|
|
|
|||
|
|
@ -274,6 +274,8 @@ class OpsTraceManager:
|
|||
raise ValueError("App not found")
|
||||
|
||||
tenant_id = app.tenant_id
|
||||
if trace_config_data.tracing_config is None:
|
||||
raise ValueError("Tracing config cannot be None.")
|
||||
decrypt_tracing_config = cls.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,20 @@
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import wandb
|
||||
import weave
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from weave.trace_server.trace_server_interface import (
|
||||
CallEndReq,
|
||||
CallStartReq,
|
||||
EndedCallSchemaForInsert,
|
||||
StartedCallSchemaForInsert,
|
||||
SummaryInsertMap,
|
||||
TraceStatus,
|
||||
)
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
|
|
@ -57,6 +65,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
)
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
self.calls: dict[str, Any] = {}
|
||||
self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
|
||||
|
||||
def get_project_url(
|
||||
self,
|
||||
|
|
@ -424,6 +433,13 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
logger.debug("Weave API check failed: %s", str(e))
|
||||
raise ValueError(f"Weave API check failed: {str(e)}")
|
||||
|
||||
def _normalize_time(self, dt: datetime | None) -> datetime:
|
||||
if dt is None:
|
||||
return datetime.now(UTC)
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
|
||||
inputs = run_data.inputs
|
||||
if inputs is None:
|
||||
|
|
@ -437,19 +453,71 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
call = self.weave_client.create_call(
|
||||
op=run_data.op,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
|
||||
if trace_id is None:
|
||||
trace_id = run_data.id
|
||||
|
||||
call_start_req = CallStartReq(
|
||||
start=StartedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
op_name=str(run_data.op),
|
||||
trace_id=trace_id,
|
||||
parent_id=parent_run_id,
|
||||
started_at=started_at,
|
||||
attributes=attributes,
|
||||
inputs=inputs,
|
||||
wb_user_id=None,
|
||||
)
|
||||
)
|
||||
self.calls[run_data.id] = call
|
||||
if parent_run_id:
|
||||
self.calls[run_data.id].parent_id = parent_run_id
|
||||
self.weave_client.server.call_start(call_start_req)
|
||||
self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
|
||||
|
||||
def finish_call(self, run_data: WeaveTraceModel):
|
||||
call = self.calls.get(run_data.id)
|
||||
if call:
|
||||
exception = Exception(run_data.exception) if run_data.exception else None
|
||||
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
|
||||
else:
|
||||
call_meta = self.calls.get(run_data.id)
|
||||
if not call_meta:
|
||||
raise ValueError(f"Call with id {run_data.id} not found")
|
||||
|
||||
attributes = run_data.attributes
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
elif not isinstance(attributes, dict):
|
||||
attributes = {"attributes": str(attributes)}
|
||||
|
||||
start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
|
||||
end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
|
||||
started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
|
||||
ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
|
||||
elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
|
||||
if elapsed_ms < 0:
|
||||
elapsed_ms = 0
|
||||
|
||||
status_counts = {
|
||||
TraceStatus.SUCCESS: 0,
|
||||
TraceStatus.ERROR: 0,
|
||||
}
|
||||
if run_data.exception:
|
||||
status_counts[TraceStatus.ERROR] = 1
|
||||
else:
|
||||
status_counts[TraceStatus.SUCCESS] = 1
|
||||
|
||||
summary: dict[str, Any] = {
|
||||
"status_counts": status_counts,
|
||||
"weave": {"latency_ms": elapsed_ms},
|
||||
}
|
||||
|
||||
exception_str = str(run_data.exception) if run_data.exception else None
|
||||
|
||||
call_end_req = CallEndReq(
|
||||
end=EndedCallSchemaForInsert(
|
||||
project_id=self.project_id,
|
||||
id=run_data.id,
|
||||
ended_at=ended_at,
|
||||
exception=exception_str,
|
||||
output=run_data.outputs,
|
||||
summary=cast(SummaryInsertMap, summary),
|
||||
)
|
||||
)
|
||||
self.weave_client.server.call_end(call_end_req)
|
||||
|
|
|
|||
|
|
@ -309,11 +309,12 @@ class ProviderManager:
|
|||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
|
||||
default_model = TenantDefaultModel()
|
||||
default_model.tenant_id = tenant_id
|
||||
default_model.model_type = model_type.to_origin_model_type()
|
||||
default_model.provider_name = available_model.provider.provider
|
||||
default_model.model_name = available_model.model
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
provider_name=available_model.provider.provider,
|
||||
model_name=available_model.model,
|
||||
)
|
||||
db.session.add(default_model)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,18 @@ logger = logging.getLogger(__name__)
|
|||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
T = TypeVar("T", bound="MatrixoneVector")
|
||||
|
||||
|
||||
def ensure_client(func: Callable[Concatenate[T, P], R]):
|
||||
@wraps(func)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(None, False)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
|
|
@ -206,19 +218,6 @@ class MatrixoneVector(BaseVector):
|
|||
self.client.delete()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MatrixoneVector)
|
||||
|
||||
|
||||
def ensure_client(func: Callable[Concatenate[T, P], R]):
|
||||
@wraps(func)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
if self.client is None:
|
||||
self.client = self._get_client(None, False)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MatrixoneVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
|
||||
if dataset.index_struct_dict:
|
||||
|
|
|
|||
|
|
@ -152,13 +152,15 @@ class WordExtractor(BaseExtractor):
|
|||
# Initialize a row, all of which are empty by default
|
||||
row_cells = [""] * total_cols
|
||||
col_index = 0
|
||||
for cell in row.cells:
|
||||
while col_index < len(row.cells):
|
||||
# make sure the col_index is not out of range
|
||||
while col_index < total_cols and row_cells[col_index] != "":
|
||||
while col_index < len(row.cells) and row_cells[col_index] != "":
|
||||
col_index += 1
|
||||
# if col_index is out of range the loop is jumped
|
||||
if col_index >= total_cols:
|
||||
if col_index >= len(row.cells):
|
||||
break
|
||||
# get the correct cell
|
||||
cell = row.cells[col_index]
|
||||
cell_content = self._parse_cell(cell, image_map).strip()
|
||||
cell_colspan = cell.grid_span or 1
|
||||
for i in range(cell_colspan):
|
||||
|
|
|
|||
|
|
@ -54,6 +54,9 @@ class TenantIsolatedTaskQueue:
|
|||
serialized_data = wrapper.serialize()
|
||||
serialized_tasks.append(serialized_data)
|
||||
|
||||
if not serialized_tasks:
|
||||
return
|
||||
|
||||
redis_client.lpush(self._queue, *serialized_tasks)
|
||||
|
||||
def pull_tasks(self, count: int = 1) -> Sequence[Any]:
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
|
|||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import Float, and_, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy import and_, or_, select
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
|
|
@ -1023,60 +1022,55 @@ class DatasetRetrieval:
|
|||
self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
|
||||
):
|
||||
if value is None and condition not in ("empty", "not empty"):
|
||||
return
|
||||
return filters
|
||||
|
||||
json_field = DatasetDocument.doc_metadata[metadata_name].as_string()
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
|
||||
filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
def _fetch_model_config(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
|||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
|
@ -32,7 +33,6 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
|
|
@ -63,7 +63,6 @@ from services.tools.tools_transform_service import ToolTransformService
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -618,12 +617,28 @@ class ToolManager:
|
|||
"""
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use DISTINCT ON
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
else:
|
||||
# MySQL: Use window function to achieve same result
|
||||
sql = """
|
||||
SELECT id FROM (
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY tenant_id, provider
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
) as rn
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
) ranked WHERE rn = 1
|
||||
"""
|
||||
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
"""Strategy for validating array elements.
|
||||
|
|
@ -155,6 +158,17 @@ class SegmentType(StrEnum):
|
|||
return isinstance(value, File)
|
||||
elif self == SegmentType.NONE:
|
||||
return value is None
|
||||
elif self == SegmentType.GROUP:
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import Segment
|
||||
|
||||
if isinstance(value, SegmentGroup):
|
||||
return all(isinstance(item, Segment) for item in value.value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return all(isinstance(item, Segment) for item in value)
|
||||
|
||||
return False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
|
@ -202,6 +216,35 @@ class SegmentType(StrEnum):
|
|||
raise ValueError(f"element_type is only supported by array type, got {self}")
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: "SegmentType"):
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
match t:
|
||||
case (
|
||||
SegmentType.ARRAY_OBJECT
|
||||
| SegmentType.ARRAY_ANY
|
||||
| SegmentType.ARRAY_STRING
|
||||
| SegmentType.ARRAY_NUMBER
|
||||
| SegmentType.ARRAY_BOOLEAN
|
||||
):
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.INTEGER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.FLOAT:
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return variable_factory.build_segment(False)
|
||||
case _:
|
||||
raise ValueError(f"unsupported variable type: {t}")
|
||||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
|
|
|
|||
|
|
@ -192,7 +192,6 @@ class GraphEngine:
|
|||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ class Dispatcher:
|
|||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
|
|
@ -53,13 +52,11 @@ class Dispatcher:
|
|||
Args:
|
||||
event_queue: Queue of events from workers
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting unhandled events
|
||||
execution_coordinator: Coordinator for execution flow
|
||||
event_emitter: Optional event manager to signal completion
|
||||
"""
|
||||
self._event_queue = event_queue
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._execution_coordinator = execution_coordinator
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
|
|
@ -86,37 +83,31 @@ class Dispatcher:
|
|||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
self._process_commands()
|
||||
while not self._stop_event.is_set():
|
||||
commands_checked = False
|
||||
should_check_commands = False
|
||||
should_break = False
|
||||
if (
|
||||
self._execution_coordinator.aborted
|
||||
or self._execution_coordinator.paused
|
||||
or self._execution_coordinator.execution_complete
|
||||
):
|
||||
break
|
||||
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
should_check_commands = True
|
||||
should_break = True
|
||||
else:
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
self._execution_coordinator.check_scaling()
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
should_check_commands = self._should_check_commands(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Process commands even when no new events arrive so abort requests are not missed
|
||||
should_check_commands = True
|
||||
time.sleep(0.1)
|
||||
|
||||
if should_check_commands and not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
commands_checked = True
|
||||
|
||||
if should_break:
|
||||
if not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
self._process_commands()
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -129,6 +120,6 @@ class Dispatcher:
|
|||
if self._event_emitter:
|
||||
self._event_emitter.mark_complete()
|
||||
|
||||
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
|
||||
"""Return True if the event represents a node completion."""
|
||||
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)
|
||||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||
self._execution_coordinator.process_commands()
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class ExecutionCoordinator:
|
|||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
def check_commands(self) -> None:
|
||||
def process_commands(self) -> None:
|
||||
"""Process any pending commands."""
|
||||
self._command_processor.process_commands()
|
||||
|
||||
|
|
@ -48,24 +48,16 @@ class ExecutionCoordinator:
|
|||
"""Check and perform worker scaling if needed."""
|
||||
self._worker_pool.check_and_scale()
|
||||
|
||||
def is_execution_complete(self) -> bool:
|
||||
"""
|
||||
Check if execution is complete.
|
||||
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Treat paused, aborted, or failed executions as terminal states
|
||||
if self._graph_execution.is_paused:
|
||||
return True
|
||||
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
@property
|
||||
def execution_complete(self):
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
def aborted(self):
|
||||
return self._graph_execution.aborted or self._graph_execution.has_error
|
||||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
"""Expose whether the underlying graph execution is paused."""
|
||||
return self._graph_execution.is_paused
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from collections import defaultdict
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
|
|
@ -597,79 +596,79 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
|||
if value is None and condition not in ("empty", "not empty"):
|
||||
return filters
|
||||
|
||||
key = f"{metadata_name}_{sequence}"
|
||||
key_value = f"{metadata_name}_{sequence}_value"
|
||||
json_field = Document.doc_metadata[metadata_name].as_string()
|
||||
|
||||
match condition:
|
||||
case "contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}%"))
|
||||
|
||||
case "not contains":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.notlike(f"%{value}%"))
|
||||
|
||||
case "start with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"{value}%"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"{value}%"))
|
||||
|
||||
case "end with":
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
|
||||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
filters.append(json_field.like(f"%{value}"))
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(False))
|
||||
else:
|
||||
filters.append(json_field.in_(value_list))
|
||||
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
value_list = [v.strip() for v in value.split(",") if v.strip()]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_list = [str(v) for v in value if v is not None]
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "=" | "is":
|
||||
value_list = [str(value)] if value is not None else []
|
||||
|
||||
if not value_list:
|
||||
filters.append(literal(True))
|
||||
else:
|
||||
filters.append(json_field.notin_(value_list))
|
||||
|
||||
case "is" | "=":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
|
||||
filters.append(json_field == value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() == value)
|
||||
|
||||
case "is not" | "≠":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
|
||||
else:
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
|
||||
filters.append(json_field != value)
|
||||
elif isinstance(value, (int, float)):
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() != value)
|
||||
|
||||
case "empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].is_(None))
|
||||
|
||||
case "not empty":
|
||||
filters.append(Document.doc_metadata[metadata_name].isnot(None))
|
||||
|
||||
case "before" | "<":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() < value)
|
||||
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() > value)
|
||||
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() <= value)
|
||||
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
|
||||
filters.append(Document.doc_metadata[metadata_name].as_float() >= value)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
return filters
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from collections.abc import Callable, Mapping, Sequence
|
|||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
|
|
@ -12,7 +11,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
|
||||
from ..common.impl import conversation_variable_updater_factory
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
|
@ -116,7 +114,7 @@ class VariableAssignerNode(Node):
|
|||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
# Over write the variable.
|
||||
|
|
@ -143,24 +141,3 @@ class VariableAssignerNode(Node):
|
|||
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.INTEGER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.FLOAT:
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
from core.variables import SegmentType
|
||||
|
||||
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
|
@ -16,7 +16,6 @@ from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNod
|
|||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
|
|
@ -249,7 +248,7 @@ class VariableAssignerNode(Node):
|
|||
case Operation.OVER_WRITE:
|
||||
return value
|
||||
case Operation.CLEAR:
|
||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
||||
return SegmentType.get_zero_value(variable.value_type).to_object()
|
||||
case Operation.APPEND:
|
||||
return variable.value + [value]
|
||||
case Operation.EXTEND:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||
import importlib
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping as TypingMapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
|
@ -100,8 +99,8 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
|||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: TypingMapping[str, object]
|
||||
edges: TypingMapping[str, object]
|
||||
nodes: Mapping[str, object]
|
||||
edges: Mapping[str, object]
|
||||
root_node: object
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
|
|
|
|||
|
|
@ -265,6 +265,45 @@ def _assert_not_empty(*, value: object) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]:
|
||||
"""
|
||||
Normalize value and expected to compatible numeric types for comparison.
|
||||
|
||||
Args:
|
||||
value: The actual numeric value (int or float)
|
||||
expected: The expected value (int, float, or str)
|
||||
|
||||
Returns:
|
||||
A tuple of (normalized_value, normalized_expected) with compatible types
|
||||
|
||||
Raises:
|
||||
ValueError: If expected cannot be converted to a number
|
||||
"""
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to number")
|
||||
|
||||
# Convert expected to appropriate numeric type
|
||||
if isinstance(expected, str):
|
||||
# Try to convert to float first to handle decimal strings
|
||||
try:
|
||||
expected_float = float(expected)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Cannot convert '{expected}' to number") from e
|
||||
|
||||
# If value is int and expected is a whole number, keep as int comparison
|
||||
if isinstance(value, int) and expected_float.is_integer():
|
||||
return value, int(expected_float)
|
||||
else:
|
||||
# Otherwise convert value to float for comparison
|
||||
return float(value) if isinstance(value, int) else value, expected_float
|
||||
elif isinstance(expected, float):
|
||||
# If expected is already float, convert int value to float
|
||||
return float(value) if isinstance(value, int) else value, expected
|
||||
else:
|
||||
# expected is int
|
||||
return value, expected
|
||||
|
||||
|
||||
def _assert_equal(*, value: object, expected: object) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
|
@ -324,18 +363,8 @@ def _assert_greater_than(*, value: object, expected: object) -> bool:
|
|||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to int")
|
||||
expected = int(expected)
|
||||
else:
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to float")
|
||||
expected = float(expected)
|
||||
|
||||
if value <= expected:
|
||||
return False
|
||||
return True
|
||||
value, expected = _normalize_numeric_values(value, expected)
|
||||
return value > expected
|
||||
|
||||
|
||||
def _assert_less_than(*, value: object, expected: object) -> bool:
|
||||
|
|
@ -345,18 +374,8 @@ def _assert_less_than(*, value: object, expected: object) -> bool:
|
|||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to int")
|
||||
expected = int(expected)
|
||||
else:
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to float")
|
||||
expected = float(expected)
|
||||
|
||||
if value >= expected:
|
||||
return False
|
||||
return True
|
||||
value, expected = _normalize_numeric_values(value, expected)
|
||||
return value < expected
|
||||
|
||||
|
||||
def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
|
||||
|
|
@ -366,18 +385,8 @@ def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
|
|||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to int")
|
||||
expected = int(expected)
|
||||
else:
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to float")
|
||||
expected = float(expected)
|
||||
|
||||
if value < expected:
|
||||
return False
|
||||
return True
|
||||
value, expected = _normalize_numeric_values(value, expected)
|
||||
return value >= expected
|
||||
|
||||
|
||||
def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
|
||||
|
|
@ -387,18 +396,8 @@ def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
|
|||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to int")
|
||||
expected = int(expected)
|
||||
else:
|
||||
if not isinstance(expected, (int, float, str)):
|
||||
raise ValueError(f"Cannot convert {type(expected)} to float")
|
||||
expected = float(expected)
|
||||
|
||||
if value > expected:
|
||||
return False
|
||||
return True
|
||||
value, expected = _normalize_numeric_values(value, expected)
|
||||
return value <= expected
|
||||
|
||||
|
||||
def _assert_null(*, value: object) -> bool:
|
||||
|
|
|
|||
|
|
@ -421,4 +421,10 @@ class WorkflowEntry:
|
|||
if len(variable_key_list) == 2 and variable_key_list[0] == "structured_output":
|
||||
input_value = {variable_key_list[1]: input_value}
|
||||
variable_key_list = variable_key_list[0:1]
|
||||
|
||||
# Support for a single node to reference multiple structured_output variables
|
||||
current_variable = variable_pool.get([variable_node_id] + variable_key_list)
|
||||
if current_variable and isinstance(current_variable.value, dict):
|
||||
input_value = current_variable.value | input_value
|
||||
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,209 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum, auto
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCharge:
|
||||
"""
|
||||
Result of a quota consumption operation.
|
||||
|
||||
Attributes:
|
||||
success: Whether the quota charge succeeded
|
||||
charge_id: UUID for refund, or None if failed/disabled
|
||||
"""
|
||||
|
||||
success: bool
|
||||
charge_id: str | None
|
||||
_quota_type: "QuotaType"
|
||||
|
||||
def refund(self) -> None:
|
||||
"""
|
||||
Refund this quota charge.
|
||||
|
||||
Safe to call even if charge failed or was disabled.
|
||||
This method guarantees no exceptions will be raised.
|
||||
"""
|
||||
if self.charge_id:
|
||||
self._quota_type.refund(self.charge_id)
|
||||
logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)
|
||||
|
||||
|
||||
class QuotaType(StrEnum):
|
||||
"""
|
||||
Supported quota types for tenant feature usage.
|
||||
|
||||
Add additional types here whenever new billable features become available.
|
||||
"""
|
||||
|
||||
# Trigger execution quota
|
||||
TRIGGER = auto()
|
||||
|
||||
# Workflow execution quota
|
||||
WORKFLOW = auto()
|
||||
|
||||
UNLIMITED = auto()
|
||||
|
||||
@property
|
||||
def billing_key(self) -> str:
|
||||
"""
|
||||
Get the billing key for the feature.
|
||||
"""
|
||||
match self:
|
||||
case QuotaType.TRIGGER:
|
||||
return "trigger_event"
|
||||
case QuotaType.WORKFLOW:
|
||||
return "api_rate_limit"
|
||||
case _:
|
||||
raise ValueError(f"Invalid quota type: {self}")
|
||||
|
||||
def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
|
||||
"""
|
||||
Consume quota for the feature.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to consume (default: 1)
|
||||
|
||||
Returns:
|
||||
QuotaCharge with success status and charge_id for refund
|
||||
|
||||
Raises:
|
||||
QuotaExceededError: When quota is insufficient
|
||||
"""
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
logger.debug("Billing disabled, allowing request for %s", tenant_id)
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=self)
|
||||
|
||||
logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to consume must be greater than 0")
|
||||
|
||||
try:
|
||||
response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)
|
||||
|
||||
if response.get("result") != "success":
|
||||
logger.warning(
|
||||
"Failed to consume quota for %s, feature %s details: %s",
|
||||
tenant_id,
|
||||
self.value,
|
||||
response.get("detail"),
|
||||
)
|
||||
raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)
|
||||
|
||||
charge_id = response.get("history_id")
|
||||
logger.debug(
|
||||
"Successfully consumed %d %s quota for tenant %s, charge_id: %s",
|
||||
amount,
|
||||
self.value,
|
||||
tenant_id,
|
||||
charge_id,
|
||||
)
|
||||
return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)
|
||||
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
# fail-safe: allow request on billing errors
|
||||
logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
|
||||
return unlimited()
|
||||
|
||||
def check(self, tenant_id: str, amount: int = 1) -> bool:
|
||||
"""
|
||||
Check if tenant has sufficient quota without consuming.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
amount: Amount to check (default: 1)
|
||||
|
||||
Returns:
|
||||
True if quota is sufficient, False otherwise
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return True
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount to check must be greater than 0")
|
||||
|
||||
try:
|
||||
remaining = self.get_remaining(tenant_id)
|
||||
return remaining >= amount if remaining != -1 else True
|
||||
except Exception:
|
||||
logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
|
||||
# fail-safe: allow request on billing errors
|
||||
return True
|
||||
|
||||
def refund(self, charge_id: str) -> None:
|
||||
"""
|
||||
Refund quota using charge_id from consume().
|
||||
|
||||
This method guarantees no exceptions will be raised.
|
||||
All errors are logged but silently handled.
|
||||
|
||||
Args:
|
||||
charge_id: The UUID returned from consume()
|
||||
"""
|
||||
try:
|
||||
from configs import dify_config
|
||||
from services.billing_service import BillingService
|
||||
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return
|
||||
|
||||
if not charge_id:
|
||||
logger.warning("Cannot refund: charge_id is empty")
|
||||
return
|
||||
|
||||
logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)
|
||||
|
||||
response = BillingService.refund_tenant_feature_plan_usage(charge_id)
|
||||
if response.get("result") == "success":
|
||||
logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
|
||||
else:
|
||||
logger.warning("Refund failed for charge_id: %s", charge_id)
|
||||
|
||||
except Exception:
|
||||
# Catch ALL exceptions - refund must never fail
|
||||
logger.exception("Failed to refund quota for charge_id: %s", charge_id)
|
||||
# Don't raise - refund is best-effort and must be silent
|
||||
|
||||
def get_remaining(self, tenant_id: str) -> int:
|
||||
"""
|
||||
Get remaining quota for the tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
|
||||
Returns:
|
||||
Remaining quota amount
|
||||
"""
|
||||
from services.billing_service import BillingService
|
||||
|
||||
try:
|
||||
usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
|
||||
# Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
|
||||
if isinstance(usage_info, dict):
|
||||
return usage_info.get("remaining", 0)
|
||||
# If it returns a simple number, treat it as remaining
|
||||
return int(usage_info) if usage_info else 0
|
||||
except Exception:
|
||||
logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
|
||||
return -1
|
||||
|
||||
|
||||
def unlimited() -> QuotaCharge:
|
||||
"""
|
||||
Return a quota charge for unlimited quota.
|
||||
|
||||
This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
|
||||
"""
|
||||
return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
|
||||
|
|
@ -10,7 +10,6 @@ from redis import RedisError
|
|||
from redis.cache import CacheConfig
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.lock import Lock
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ class ClickZettaVolumeConfig(BaseModel):
|
|||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
import os
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from google.cloud import storage as google_cloud_storage
|
||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ app_partial_fields = {
|
|||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ dataset_detail_fields = {
|
|||
"document_count": fields.Integer,
|
||||
"word_count": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"author_name": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .channel import BroadcastChannel
|
||||
from .sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel"]
|
||||
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,205 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisSubscriptionBase(Subscription):
|
||||
"""Base class for Redis pub/sub subscriptions with common functionality.
|
||||
|
||||
This class provides shared functionality for both regular and sharded
|
||||
Redis pub/sub subscriptions, reducing code duplication and improving
|
||||
maintainability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||
self._dropped_count = 0
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""Start the subscription if not already started."""
|
||||
with self._start_lock:
|
||||
if self._started:
|
||||
return
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
if self._pubsub is None:
|
||||
raise SubscriptionClosedError(
|
||||
f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
|
||||
)
|
||||
|
||||
self._subscribe()
|
||||
_logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener_thread.start()
|
||||
self._started = True
|
||||
|
||||
def _listen(self) -> None:
|
||||
"""Main listener loop for processing messages."""
|
||||
pubsub = self._pubsub
|
||||
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||
while not self._closed.is_set():
|
||||
raw_message = self._get_message()
|
||||
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
if raw_message.get("type") != self._get_message_type():
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning(
|
||||
"Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
|
||||
)
|
||||
continue
|
||||
|
||||
payload_bytes: bytes | None = raw_message.get("data")
|
||||
if not isinstance(payload_bytes, bytes):
|
||||
_logger.error(
|
||||
"Received invalid data from %s channel %s, type=%s",
|
||||
self._get_subscription_type(),
|
||||
self._topic,
|
||||
type(payload_bytes),
|
||||
)
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
|
||||
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
|
||||
self._unsubscribe()
|
||||
pubsub.close()
|
||||
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
|
||||
self._pubsub = None
|
||||
|
||||
def _enqueue_message(self, payload: bytes) -> None:
|
||||
"""Enqueue a message to the internal queue with dropping behavior."""
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
self._queue.put_nowait(payload)
|
||||
return
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
self._dropped_count += 1
|
||||
_logger.debug(
|
||||
"Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
|
||||
self._get_subscription_type(),
|
||||
self._topic,
|
||||
self._dropped_count,
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
return
|
||||
|
||||
def _message_iterator(self) -> Generator[bytes, None, None]:
|
||||
"""Iterator for consuming messages from the subscription."""
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
item = self._queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
"""Return an iterator over messages from the subscription."""
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
"""Receive the next message from the subscription."""
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
return item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Context manager entry point."""
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
"""Context manager exit point."""
|
||||
self.close()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the subscription and clean up resources."""
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
|
||||
# message retrieval method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=1.0)
|
||||
self._listener_thread = None
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
def _get_subscription_type(self) -> str:
|
||||
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _subscribe(self) -> None:
|
||||
"""Subscribe to the Redis topic using the appropriate command."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _unsubscribe(self) -> None:
|
||||
"""Unsubscribe from the Redis topic using the appropriate command."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_message(self) -> dict | None:
|
||||
"""Get a message from Redis using the appropriate method."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
"""Return the expected message type (e.g., 'message' or 'smessage')."""
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,24 +1,15 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
Redis Pub/Sub based broadcast channel implementation.
|
||||
Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).
|
||||
|
||||
Provides "at most once" delivery semantics for messages published to channels.
|
||||
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
Provides "at most once" delivery semantics for messages published to channels
|
||||
using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
|
||||
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
|
||||
"""
|
||||
|
|
@ -54,147 +45,23 @@ class Topic:
|
|||
)
|
||||
|
||||
|
||||
class _RedisSubscription(Subscription):
|
||||
def __init__(
|
||||
self,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||
self._dropped_count = 0
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
class _RedisSubscription(RedisSubscriptionBase):
|
||||
"""Regular Redis pub/sub subscription implementation."""
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
with self._start_lock:
|
||||
if self._started:
|
||||
return
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
if self._pubsub is None:
|
||||
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
self._pubsub.subscribe(self._topic)
|
||||
_logger.debug("Subscribed to channel %s", self._topic)
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.subscribe(self._topic)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-broadcast-{self._topic}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener_thread.start()
|
||||
self._started = True
|
||||
def _unsubscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.unsubscribe(self._topic)
|
||||
|
||||
def _listen(self) -> None:
|
||||
pubsub = self._pubsub
|
||||
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||
while not self._closed.is_set():
|
||||
raw_message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
if raw_message.get("type") != "message":
|
||||
continue
|
||||
|
||||
channel_field = raw_message.get("channel")
|
||||
if isinstance(channel_field, bytes):
|
||||
channel_name = channel_field.decode("utf-8")
|
||||
elif isinstance(channel_field, str):
|
||||
channel_name = channel_field
|
||||
else:
|
||||
channel_name = str(channel_field)
|
||||
|
||||
if channel_name != self._topic:
|
||||
_logger.warning("Ignoring message from unexpected channel %s", channel_name)
|
||||
continue
|
||||
|
||||
payload_bytes: bytes | None = raw_message.get("data")
|
||||
if not isinstance(payload_bytes, bytes):
|
||||
_logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
|
||||
_logger.debug("Listener thread stopped for channel %s", self._topic)
|
||||
pubsub.unsubscribe(self._topic)
|
||||
pubsub.close()
|
||||
_logger.debug("PubSub closed for topic %s", self._topic)
|
||||
self._pubsub = None
|
||||
|
||||
def _enqueue_message(self, payload: bytes) -> None:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
self._queue.put_nowait(payload)
|
||||
return
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
self._dropped_count += 1
|
||||
_logger.debug(
|
||||
"Dropped message from Redis subscription, topic=%s, total_dropped=%d",
|
||||
self._topic,
|
||||
self._dropped_count,
|
||||
)
|
||||
except queue.Empty:
|
||||
continue
|
||||
return
|
||||
|
||||
def _message_iterator(self) -> Generator[bytes, None, None]:
|
||||
while not self._closed.is_set():
|
||||
try:
|
||||
item = self._queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
self._start_if_needed()
|
||||
|
||||
try:
|
||||
item = self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
return item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._start_if_needed()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: types.TracebackType | None,
|
||||
) -> bool | None:
|
||||
self.close()
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the `PubSub.get_message`
|
||||
# method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=1.0)
|
||||
self._listener_thread = None
|
||||
def _get_message_type(self) -> str:
|
||||
return "message"
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue