mirror of https://github.com/langgenius/dify.git

commit ff29b68b27: Merge branch 'main' into feature/tool-invocation-add-inputs-param
@@ -29,7 +29,7 @@ trim_trailing_whitespace = false
 # Matches multiple files with brace expansion notation
 # Set default charset
-[*.{js,tsx}]
+[*.{js,jsx,ts,tsx,mjs}]
 indent_style = space
 indent_size = 2

@@ -62,7 +62,7 @@ jobs:
         compose-file: |
           docker/docker-compose.middleware.yaml
         services: |
-          db
+          db_postgres
           redis
           sandbox
           ssrf_proxy

@@ -8,7 +8,7 @@ concurrency:
   cancel-in-progress: true

 jobs:
-  db-migration-test:
+  db-migration-test-postgres:
     runs-on: ubuntu-latest

     steps:
@@ -45,7 +45,7 @@ jobs:
         compose-file: |
           docker/docker-compose.middleware.yaml
         services: |
-          db
+          db_postgres
           redis

     - name: Prepare configs
@@ -57,3 +57,60 @@ jobs:
         env:
           DEBUG: true
         run: uv run --directory api flask upgrade-db
+
+  db-migration-test-mysql:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          persist-credentials: false
+
+      - name: Setup UV and Python
+        uses: astral-sh/setup-uv@v6
+        with:
+          enable-cache: true
+          python-version: "3.12"
+          cache-dependency-glob: api/uv.lock
+
+      - name: Install dependencies
+        run: uv sync --project api
+
+      - name: Ensure Offline migration are supported
+        run: |
+          # upgrade
+          uv run --directory api flask db upgrade 'base:head' --sql
+          # downgrade
+          uv run --directory api flask db downgrade 'head:base' --sql
+
+      - name: Prepare middleware env for MySQL
+        run: |
+          cd docker
+          cp middleware.env.example middleware.env
+          sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
+          sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
+          sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
+          sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
+
+      - name: Set up Middlewares
+        uses: hoverkraft-tech/compose-action@v2.0.2
+        with:
+          compose-file: |
+            docker/docker-compose.middleware.yaml
+          services: |
+            db_mysql
+            redis
+
+      - name: Prepare configs for MySQL
+        run: |
+          cd api
+          cp .env.example .env
+          sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
+          sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
+          sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
+
+      - name: Run DB Migration
+        env:
+          DEBUG: true
+        run: uv run --directory api flask upgrade-db

@@ -1,10 +1,7 @@
 name: Run VDB Tests

 on:
-  push:
-    branches: [main]
-    paths:
-      - 'api/core/rag/*.py'
   workflow_call:

 concurrency:
   group: vdb-tests-${{ github.head_ref || github.run_id }}
@@ -54,13 +51,13 @@ jobs:
       - name: Expose Service Ports
         run: sh .github/workflows/expose_service_ports.sh

-      - name: Set up Vector Store (TiDB)
-        uses: hoverkraft-tech/compose-action@v2.0.2
-        with:
-          compose-file: docker/tidb/docker-compose.yaml
-          services: |
-            tidb
-            tiflash
+      # - name: Set up Vector Store (TiDB)
+      #   uses: hoverkraft-tech/compose-action@v2.0.2
+      #   with:
+      #     compose-file: docker/tidb/docker-compose.yaml
+      #     services: |
+      #       tidb
+      #       tiflash

       - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
         uses: hoverkraft-tech/compose-action@v2.0.2
@@ -86,8 +83,8 @@ jobs:
           ls -lah .
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

-      - name: Check VDB Ready (TiDB)
-        run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      # - name: Check VDB Ready (TiDB)
+      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

       - name: Test Vector Stores
         run: uv run --project api bash dev/pytest/pytest_vdb.sh

@@ -186,6 +186,8 @@ docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
 docker/volumes/plugin_daemon/*
 docker/volumes/matrixone/*
+docker/volumes/mysql/*
+docker/volumes/seekdb/*
 !docker/volumes/oceanbase/init.d

 docker/nginx/conf.d/default.conf

@@ -37,7 +37,7 @@
             "-c",
             "1",
             "-Q",
-            "dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
+            "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
            "--loglevel",
            "INFO"
        ],

Makefile (8 changed lines)
@@ -70,6 +70,11 @@ type-check:
     @uv run --directory api --dev basedpyright
     @echo "✅ Type check complete"

+test:
+    @echo "🧪 Running backend unit tests..."
+    @uv run --project api --dev dev/pytest/pytest_unit_tests.sh
+    @echo "✅ Tests complete"
+
 # Build Docker images
 build-web:
     @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@@ -119,6 +124,7 @@ help:
     @echo "  make check       - Check code with ruff"
     @echo "  make lint        - Format and fix code with ruff"
     @echo "  make type-check  - Run type checking with basedpyright"
+    @echo "  make test        - Run backend unit tests"
     @echo ""
     @echo "Docker Build Targets:"
     @echo "  make build-web   - Build web Docker image"
@@ -128,4 +134,4 @@ help:
     @echo "  make build-push-all - Build and push all Docker images"

 # Phony targets
-.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
+.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test

@@ -72,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
 CELERY_BACKEND=redis
-# PostgreSQL database configuration
+
+# Database configuration
+DB_TYPE=postgresql
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
 DB_HOST=localhost
 DB_PORT=5432
 DB_DATABASE=dify

 SQLALCHEMY_POOL_PRE_PING=true
 SQLALCHEMY_POOL_TIMEOUT=30
@@ -163,7 +166,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 COOKIE_DOMAIN=

 # Vector database configuration
-# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
+# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
 VECTOR_STORE=weaviate
 # Prefix used to create collection name in vector database
 VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -174,6 +177,17 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
 WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100

+# OceanBase Vector configuration
+OCEANBASE_VECTOR_HOST=127.0.0.1
+OCEANBASE_VECTOR_PORT=2881
+OCEANBASE_VECTOR_USER=root@test
+OCEANBASE_VECTOR_PASSWORD=difyai123456
+OCEANBASE_VECTOR_DATABASE=test
+OCEANBASE_MEMORY_LIMIT=6G
+OCEANBASE_ENABLE_HYBRID_SEARCH=false
+OCEANBASE_FULLTEXT_PARSER=ik
+SEEKDB_MEMORY_LIMIT=2G
+
 # Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
 QDRANT_URL=http://localhost:6333
 QDRANT_API_KEY=difyai123456
@@ -339,15 +353,6 @@ LINDORM_PASSWORD=admin
 LINDORM_USING_UGC=True
 LINDORM_QUERY_TIMEOUT=1

-# OceanBase Vector configuration
-OCEANBASE_VECTOR_HOST=127.0.0.1
-OCEANBASE_VECTOR_PORT=2881
-OCEANBASE_VECTOR_USER=root@test
-OCEANBASE_VECTOR_PASSWORD=difyai123456
-OCEANBASE_VECTOR_DATABASE=test
-OCEANBASE_MEMORY_LIMIT=6G
-OCEANBASE_ENABLE_HYBRID_SEARCH=false
-
 # AlibabaCloud MySQL Vector configuration
 ALIBABACLOUD_MYSQL_HOST=127.0.0.1
 ALIBABACLOUD_MYSQL_PORT=3306

@@ -15,8 +15,8 @@
    ```bash
    cd ../docker
    cp middleware.env.example middleware.env
-   # change the profile to other vector database if you are not using weaviate
-   docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
+   # change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
+   docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
    cd ../api
    ```
@@ -84,7 +84,7 @@
 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.

    ```bash
-   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
+   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
    ```

 Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

@@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp:
     """
     dify_app = DifyApp(__name__)
     dify_app.config.from_mapping(dify_config.model_dump())
+    dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True

     # add before request hook
     @dify_app.before_request
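
Note: RESTX_INCLUDE_ALL_MODELS is a flask-restx config flag that makes every registered model appear in the generated Swagger document, even models no endpoint references directly; it pairs with the `console_ns.schema_model(...)` registration added later in this commit. A minimal sketch, assuming only standard flask-restx behavior:

```python
from flask import Flask
from flask_restx import Api, fields

app = Flask(__name__)
# Without this flag, models that are registered but never referenced by an
# @api.expect / @api.marshal_with decorator are omitted from the Swagger spec.
app.config["RESTX_INCLUDE_ALL_MODELS"] = True

api = Api(app)
api.model("Unreferenced", {"name": fields.String})
```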

@@ -77,10 +77,6 @@ class AppExecutionConfig(BaseSettings):
         description="Maximum number of concurrent active requests per app (0 for unlimited)",
         default=0,
     )
-    APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
-        description="Maximum number of requests per app per day",
-        default=5000,
-    )


 class CodeExecutionSandboxConfig(BaseSettings):
@@ -1086,7 +1082,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
     )
     TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
         description="Proactive credential refresh threshold in seconds",
-        default=180,
+        default=60 * 60,
     )
     TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
         description="Proactive subscription refresh threshold in seconds",

@@ -105,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):


 class DatabaseConfig(BaseSettings):
+    # Database type selector
+    DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
+        description="Database type to use. OceanBase is MySQL-compatible.",
+        default="postgresql",
+    )
+
     DB_HOST: str = Field(
         description="Hostname or IP address of the database server.",
         default="localhost",
@@ -140,10 +146,10 @@ class DatabaseConfig(BaseSettings):
         default="",
     )

-    SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
-        description="Database URI scheme for SQLAlchemy connection.",
-        default="postgresql",
-    )
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
+        return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"

     @computed_field  # type: ignore[prop-decorator]
     @property
@@ -204,15 +210,15 @@ class DatabaseConfig(BaseSettings):
         # Parse DB_EXTRAS for 'options'
         db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
         options = db_extras_dict.get("options", "")
-        # Always include timezone
-        timezone_opt = "-c timezone=UTC"
-        if options:
-            # Merge user options and timezone
-            merged_options = f"{options} {timezone_opt}"
-        else:
-            merged_options = timezone_opt

-        connect_args = {"options": merged_options}
+        connect_args = {}
+        # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
+        if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
+            timezone_opt = "-c timezone=UTC"
+            if options:
+                merged_options = f"{options} {timezone_opt}"
+            else:
+                merged_options = timezone_opt
+            connect_args = {"options": merged_options}

         return {
             "pool_size": self.SQLALCHEMY_POOL_SIZE,

@@ -104,14 +104,11 @@ class BaseApiKeyResource(Resource):
     resource_model: type | None = None
     resource_id_field: str | None = None

-    def delete(self, resource_id, api_key_id):
+    def delete(self, resource_id: str, api_key_id: str):
         assert self.resource_id_field is not None, "resource_id_field must be set"
         resource_id = str(resource_id)
         api_key_id = str(api_key_id)
         current_user, current_tenant_id = current_account_with_tenant()
         _get_resource(resource_id, current_tenant_id, self.resource_model)

-        # The role of the current user in the ta table must be admin or owner
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()

@@ -3,7 +3,7 @@ import uuid
 from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import BadRequest, Forbidden, abort
+from werkzeug.exceptions import BadRequest, abort

 from controllers.console import api, console_ns
 from controllers.console.app.wraps import get_app_model
@@ -12,6 +12,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_resource_check,
     edit_permission_required,
     enterprise_license_required,
+    is_admin_or_owner_required,
     setup_required,
 )
 from core.ops.ops_trace_manager import OpsTraceManager
@@ -250,10 +251,8 @@ class AppApi(Resource):
         args = parser.parse_args()

         app_service = AppService()
-        # Construct ArgsDict from parsed arguments
-        from services.app_service import AppService as AppServiceType
-
-        args_dict: AppServiceType.ArgsDict = {
+        args_dict: AppService.ArgsDict = {
             "name": args["name"],
             "description": args.get("description", ""),
             "icon_type": args.get("icon_type", ""),
@@ -487,15 +486,11 @@ class AppApiStatus(Resource):
     @api.response(403, "Insufficient permissions")
     @setup_required
     @login_required
+    @is_admin_or_owner_required
     @account_initialization_required
     @get_app_model
     @marshal_with(app_detail_fields)
     def post(self, app_model):
-        # The role of the current user in the ta table must be admin or owner
-        current_user, _ = current_account_with_tenant()
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
-
         parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
         args = parser.parse_args()
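
Note: the recurring refactor in this commit replaces inline `is_admin_or_owner` / `has_edit_permission` checks with the `is_admin_or_owner_required` and `edit_permission_required` decorators imported from controllers.console.wraps. Their implementations are not part of this diff; a minimal sketch of how such a guard could work, with `current_account_with_tenant` stubbed so the example stands alone (the real helper lives in libs.login):

```python
from dataclasses import dataclass
from functools import wraps

from werkzeug.exceptions import Forbidden


@dataclass
class _Account:
    is_admin_or_owner: bool


def current_account_with_tenant():
    # Stub standing in for libs.login.current_account_with_tenant.
    return _Account(is_admin_or_owner=True), "tenant-id"


def is_admin_or_owner_required(view):
    """Reject the request unless the current account is a workspace admin or owner."""

    @wraps(view)
    def wrapper(*args, **kwargs):
        current_user, _ = current_account_with_tenant()
        if not current_user.is_admin_or_owner:
            raise Forbidden()
        return view(*args, **kwargs)

    return wrapper


@is_admin_or_owner_required
def delete_api_key():
    return {"result": "success"}


print(delete_api_key())  # passes the guard with the stubbed admin account
```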

@@ -3,11 +3,10 @@ from typing import cast

 from flask import request
 from flask_restx import Resource, fields
-from werkzeug.exceptions import Forbidden

 from controllers.console import api, console_ns
 from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from core.agent.entities import AgentToolEntity
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
@@ -48,15 +47,12 @@ class ModelConfigResource(Resource):
     @api.response(404, "App not found")
     @setup_required
     @login_required
+    @edit_permission_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
     def post(self, app_model):
         """Modify app model config"""
         current_user, current_tenant_id = current_account_with_tenant()

-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
             tenant_id=current_tenant_id,

@@ -1,10 +1,15 @@
 from flask_restx import Resource, fields, marshal_with, reqparse
-from werkzeug.exceptions import Forbidden, NotFound
+from werkzeug.exceptions import NotFound

 from constants.languages import supported_language
 from controllers.console import api, console_ns
 from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import (
+    account_initialization_required,
+    edit_permission_required,
+    is_admin_or_owner_required,
+    setup_required,
+)
 from extensions.ext_database import db
 from fields.app_fields import app_site_fields
 from libs.datetime_utils import naive_utc_now
@@ -76,17 +81,13 @@ class AppSite(Resource):
     @api.response(404, "App not found")
     @setup_required
     @login_required
+    @edit_permission_required
     @account_initialization_required
     @get_app_model
     @marshal_with(app_site_fields)
     def post(self, app_model):
         args = parse_app_site_args()
-        current_user, _ = current_account_with_tenant()
-
-        # The role of the current user in the ta table must be editor, admin, or owner
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         site = db.session.query(Site).where(Site.app_id == app_model.id).first()
         if not site:
             raise NotFound
@@ -130,16 +131,12 @@ class AppSiteAccessTokenReset(Resource):
     @api.response(404, "App or site not found")
     @setup_required
     @login_required
+    @is_admin_or_owner_required
     @account_initialization_required
     @get_app_model
     @marshal_with(app_site_fields)
     def post(self, app_model):
-        # The role of the current user in the ta table must be admin or owner
-        current_user, _ = current_account_with_tenant()
-
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
-
         site = db.session.query(Site).where(Site.app_id == app_model.id).first()

         if not site:

@@ -10,9 +10,9 @@ from controllers.console.wraps import account_initialization_required, setup_req
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from libs.datetime_utils import parse_time_range
-from libs.helper import DatetimeString
+from libs.helper import DatetimeString, convert_datetime_to_date
 from libs.login import current_account_with_tenant, login_required
-from models import AppMode, Message
+from models import AppMode


 @console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
@@ -44,8 +44,9 @@ class DailyMessageStatistic(Resource):
         )
         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     COUNT(*) AS message_count
 FROM
     messages
@@ -106,6 +107,17 @@ class DailyConversationStatistic(Resource):
         account, _ = current_account_with_tenant()

         args = parser.parse_args()

+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
+    COUNT(DISTINCT conversation_id) AS conversation_count
+FROM
+    messages
+WHERE
+    app_id = :app_id
+    AND invoke_from != :invoke_from"""
+        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
         assert account.timezone is not None

         try:
@@ -113,30 +125,21 @@ class DailyConversationStatistic(Resource):
         except ValueError as e:
             abort(400, description=str(e))

-        stmt = (
-            sa.select(
-                sa.func.date(
-                    sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
-                ).label("date"),
-                sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
-            )
-            .select_from(Message)
-            .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
-        )
-
         if start_datetime_utc:
-            stmt = stmt.where(Message.created_at >= start_datetime_utc)
+            sql_query += " AND created_at >= :start"
+            arg_dict["start"] = start_datetime_utc

         if end_datetime_utc:
-            stmt = stmt.where(Message.created_at < end_datetime_utc)
+            sql_query += " AND created_at < :end"
+            arg_dict["end"] = end_datetime_utc

-        stmt = stmt.group_by("date").order_by("date")
+        sql_query += " GROUP BY date ORDER BY date"

         response_data = []
         with db.engine.begin() as conn:
-            rs = conn.execute(stmt, {"tz": account.timezone})
-            for row in rs:
-                response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
+            rs = conn.execute(sa.text(sql_query), arg_dict)
+            for i in rs:
+                response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})

         return jsonify({"data": response_data})
@@ -161,8 +164,9 @@ class DailyTerminalsStatistic(Resource):

         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
 FROM
     messages
@@ -217,8 +221,9 @@ class DailyTokenCostStatistic(Resource):

         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
     SUM(total_price) AS total_price
 FROM
@@ -276,8 +281,9 @@ class AverageSessionInteractionStatistic(Resource):

         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("c.created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     AVG(subquery.message_count) AS interactions
 FROM
     (
@@ -351,8 +357,9 @@ class UserSatisfactionRateStatistic(Resource):

         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("m.created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     COUNT(m.id) AS message_count,
     COUNT(mf.id) AS feedback_count
 FROM
@@ -416,8 +423,9 @@ class AverageResponseTimeStatistic(Resource):

         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     AVG(provider_response_latency) AS latency
 FROM
     messages
@@ -471,8 +479,9 @@ class TokensPerSecondStatistic(Resource):
         account, _ = current_account_with_tenant()
         args = parser.parse_args()

-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+        converted_created_at = convert_datetime_to_date("created_at")
+        sql_query = f"""SELECT
+    {converted_created_at} AS date,
     CASE
         WHEN SUM(provider_response_latency) = 0 THEN 0
         ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
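
Note: every statistics endpoint now builds its `date` column through `convert_datetime_to_date`, so the Postgres-only `DATE(DATE_TRUNC('day', ... AT TIME ZONE :tz))` expression lives in one helper that can emit dialect-specific SQL. The helper's body is not shown in this diff; a plausible standalone sketch under that assumption — the `db_type` parameter and the MySQL `CONVERT_TZ` branch are guesses, not taken from the diff (the real helper lives in libs.helper):

```python
def convert_datetime_to_date(column: str, db_type: str = "postgresql") -> str:
    """Return a SQL fragment converting a UTC timestamp column to a local date.

    The callers in statistic.py still pass the timezone as a ``:tz`` bind
    parameter, so the fragment must reference :tz.
    """
    if db_type == "postgresql":
        return f"DATE(DATE_TRUNC('day', {column} AT TIME ZONE 'UTC' AT TIME ZONE :tz ))"
    # MySQL / OceanBase branch -- an assumption for illustration.
    return f"DATE(CONVERT_TZ({column}, '+00:00', :tz))"


print(convert_datetime_to_date("created_at"))
print(convert_datetime_to_date("m.created_at", db_type="mysql"))
```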

@@ -983,8 +983,9 @@ class DraftWorkflowTriggerRunApi(Resource):
         Poll for trigger events and execute full workflow when event arrives
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
+        parser = reqparse.RequestParser().add_argument(
+            "node_id", type=str, required=True, location="json", nullable=False
+        )
         args = parser.parse_args()
         node_id = args["node_id"]
         workflow_service = WorkflowService()
@@ -1136,8 +1137,9 @@ class DraftWorkflowTriggerRunAllApi(Resource):
         """
         current_user, _ = current_account_with_tenant()

-        parser = reqparse.RequestParser()
-        parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
+        parser = reqparse.RequestParser().add_argument(
+            "node_ids", type=list, required=True, location="json", nullable=False
+        )
         args = parser.parse_args()
         node_ids = args["node_ids"]
         workflow_service = WorkflowService()
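
Note: these hunks (and several later ones) rely on flask-restx's `RequestParser.add_argument` returning the parser itself, which turns a mutate-in-place sequence into a single chained expression. A short illustration of the same pattern:

```python
from flask_restx import reqparse

# add_argument() returns the parser, so multiple arguments chain naturally.
parser = (
    reqparse.RequestParser()
    .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
    .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)

# Inside a request handler you would then call:
# args = parser.parse_args()
```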

@@ -1,17 +1,18 @@
 import logging
-from typing import NoReturn
+from collections.abc import Callable
+from functools import wraps
+from typing import NoReturn, ParamSpec, TypeVar

 from flask import Response
 from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden

 from controllers.console import api, console_ns
 from controllers.console.app.error import (
     DraftWorkflowNotExist,
 )
 from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.file import helpers as file_helpers
 from core.variables.segment_group import SegmentGroup
@@ -21,8 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
-from libs.login import current_user, login_required
-from models import Account, App, AppMode
+from libs.login import login_required
+from models import App, AppMode
 from models.workflow import WorkflowDraftVariable
 from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
 from services.workflow_service import WorkflowService
@@ -140,8 +141,11 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
     "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
 }

+P = ParamSpec("P")
+R = TypeVar("R")

-def _api_prerequisite(f):
+
+def _api_prerequisite(f: Callable[P, R]):
     """Common prerequisites for all draft workflow variable APIs.

     It ensures the following conditions are satisfied:
@@ -155,11 +159,10 @@ def _api_prerequisite(f):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    def wrapper(*args, **kwargs):
-        assert isinstance(current_user, Account)
-        if not current_user.has_edit_permission:
-            raise Forbidden()
+    @wraps(f)
+    def wrapper(*args: P.args, **kwargs: P.kwargs):
         return f(*args, **kwargs)

     return wrapper
@@ -167,6 +170,7 @@ def _api_prerequisite(f):

 @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
 class WorkflowVariableCollectionApi(Resource):
+    @api.expect(_create_pagination_parser())
     @api.doc("get_workflow_variables")
     @api.doc(description="Get draft workflow variables")
     @api.doc(params={"app_id": "Application ID"})
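
Note: the `_api_prerequisite` rewrite is the standard recipe for a typed pass-through decorator: `ParamSpec`/`TypeVar` preserve the wrapped view's signature for type checkers, and `functools.wraps` preserves its metadata. A standalone sketch of the same pattern (only the decorator shape is taken from the diff; the `logged` example is illustrative):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def logged(f: Callable[P, R]) -> Callable[P, R]:
    """A pass-through decorator that keeps the wrapped function's exact signature."""

    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {f.__name__}")
        return f(*args, **kwargs)

    return wrapper


@logged
def add(a: int, b: int) -> int:
    return a + b


assert add(1, 2) == 3  # type checkers still see (int, int) -> int
```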

@@ -3,12 +3,12 @@ import logging
 from flask_restx import Resource, marshal_with, reqparse
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden, NotFound
+from werkzeug.exceptions import NotFound

 from configs import dify_config
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from extensions.ext_database import db
 from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
 from libs.login import current_user, login_required
@@ -29,8 +29,7 @@ class WebhookTriggerApi(Resource):
     @marshal_with(webhook_trigger_fields)
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
+        parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
         args = parser.parse_args()

         node_id = str(args["node_id"])
@@ -95,19 +94,19 @@ class AppTriggerEnableApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
     @marshal_with(trigger_fields)
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
-        parser = reqparse.RequestParser()
-        parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
+            .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
+        )
         args = parser.parse_args()

         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         trigger_id = args["trigger_id"]

@@ -1,8 +1,8 @@
 from flask_restx import Resource, reqparse
-from werkzeug.exceptions import Forbidden

 from controllers.console import console_ns
 from controllers.console.auth.error import ApiKeyAuthFailedError
+from controllers.console.wraps import is_admin_or_owner_required
 from libs.login import current_account_with_tenant, login_required
 from services.auth.api_key_auth_service import ApiKeyAuthService
@@ -39,12 +39,10 @@ class ApiKeyAuthDataSourceBinding(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @is_admin_or_owner_required
     def post(self):
-        # The role of the current user in the table must be admin or owner
-        current_user, current_tenant_id = current_account_with_tenant()
-
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
+        _, current_tenant_id = current_account_with_tenant()
         parser = (
             reqparse.RequestParser()
             .add_argument("category", type=str, required=True, nullable=False, location="json")
@@ -65,12 +63,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @is_admin_or_owner_required
     def delete(self, binding_id):
-        # The role of the current user in the table must be admin or owner
-        current_user, current_tenant_id = current_account_with_tenant()
-
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
+        _, current_tenant_id = current_account_with_tenant()

         ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)

@@ -3,11 +3,11 @@ import logging
 import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource, fields
-from werkzeug.exceptions import Forbidden

 from configs import dify_config
 from controllers.console import api, console_ns
-from libs.login import current_account_with_tenant, login_required
+from controllers.console.wraps import is_admin_or_owner_required
+from libs.login import login_required
 from libs.oauth_data_source import NotionOAuth

 from ..wraps import account_initialization_required, setup_required
@@ -42,11 +42,9 @@ class OAuthDataSource(Resource):
     )
     @api.response(400, "Invalid provider")
     @api.response(403, "Admin privileges required")
+    @is_admin_or_owner_required
     def get(self, provider: str):
-        # The role of the current user in the table must be admin or owner
-        current_user, _ = current_account_with_tenant()
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
         with current_app.app_context():
             oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)

@@ -1,6 +1,9 @@
-from flask_restx import Resource, reqparse
+import base64

-from controllers.console import console_ns
+from flask_restx import Resource, fields, reqparse
+from werkzeug.exceptions import BadRequest
+
+from controllers.console import api, console_ns
 from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
 from enums.cloud_plan import CloudPlan
 from libs.login import current_account_with_tenant, login_required
@@ -41,3 +44,37 @@ class Invoices(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         BillingService.is_tenant_owner_or_admin(current_user)
         return BillingService.get_invoices(current_user.email, current_tenant_id)
+
+
+@console_ns.route("/billing/partners/<string:partner_key>/tenants")
+class PartnerTenants(Resource):
+    @api.doc("sync_partner_tenants_bindings")
+    @api.doc(description="Sync partner tenants bindings")
+    @api.doc(params={"partner_key": "Partner key"})
+    @api.expect(
+        api.model(
+            "SyncPartnerTenantsBindingsRequest",
+            {"click_id": fields.String(required=True, description="Click Id from partner referral link")},
+        )
+    )
+    @api.response(200, "Tenants synced to partner successfully")
+    @api.response(400, "Invalid partner information")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @only_edition_cloud
+    def put(self, partner_key: str):
+        current_user, _ = current_account_with_tenant()
+        parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
+        args = parser.parse_args()
+
+        try:
+            click_id = args["click_id"]
+            decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
+        except Exception:
+            raise BadRequest("Invalid partner_key")
+
+        if not click_id or not decoded_partner_key or not current_user.id:
+            raise BadRequest("Invalid partner information")
+
+        return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
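
Note: the `partner_key` path segment arrives base64-encoded, and decoding malformed input raises `binascii.Error`, which is why the decode is wrapped in a try/except that maps failures to a 400. A quick illustration:

```python
import base64
import binascii

print(base64.b64decode(b"ZGlmeQ==").decode("utf-8"))  # "dify"

try:
    base64.b64decode("not-valid-base64!", validate=True)
except binascii.Error as e:
    print(f"rejected: {e}")
```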

@@ -15,6 +15,7 @@ from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_rate_limit_check,
     enterprise_license_required,
+    is_admin_or_owner_required,
     setup_required,
 )
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@@ -753,13 +754,11 @@ class DatasetApiKeyApi(Resource):

     @setup_required
     @login_required
+    @is_admin_or_owner_required
     @account_initialization_required
     @marshal_with(api_key_fields)
     def post(self):
-        # The role of the current user in the ta table must be admin or owner
-        current_user, current_tenant_id = current_account_with_tenant()
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
+        _, current_tenant_id = current_account_with_tenant()

         current_key_count = (
             db.session.query(ApiToken)
@@ -794,15 +793,11 @@ class DatasetApiDeleteApi(Resource):
     @api.response(204, "API key deleted successfully")
     @setup_required
     @login_required
+    @is_admin_or_owner_required
     @account_initialization_required
     def delete(self, api_key_id):
-        current_user, current_tenant_id = current_account_with_tenant()
+        _, current_tenant_id = current_account_with_tenant()
         api_key_id = str(api_key_id)

-        # The role of the current user in the ta table must be admin or owner
-        if not current_user.is_admin_or_owner:
-            raise Forbidden()
-
         key = (
             db.session.query(ApiToken)
             .where(

@@ -162,6 +162,7 @@ class DatasetDocumentListApi(Resource):
             "keyword": "Search keyword",
             "sort": "Sort order (default: -created_at)",
             "fetch": "Fetch full details (default: false)",
+            "status": "Filter documents by display status",
         }
     )
     @api.response(200, "Documents retrieved successfully")
@@ -175,6 +176,7 @@ class DatasetDocumentListApi(Resource):
         limit = request.args.get("limit", default=20, type=int)
         search = request.args.get("keyword", default=None, type=str)
         sort = request.args.get("sort", default="-created_at", type=str)
+        status = request.args.get("status", default=None, type=str)
         # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
         try:
             fetch_val = request.args.get("fetch", default="false")
@@ -203,6 +205,9 @@ class DatasetDocumentListApi(Resource):

         query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)

+        if status:
+            query = DocumentService.apply_display_status_filter(query, status)
+
         if search:
             search = f"%{search}%"
             query = query.where(Document.name.like(search))

@@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 import services
 from controllers.console import api, console_ns
 from controllers.console.datasets.error import DatasetNameDuplicateError
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from fields.dataset_fields import dataset_detail_fields
 from libs.login import current_account_with_tenant, login_required
 from services.dataset_service import DatasetService
@@ -200,12 +200,10 @@ class ExternalDatasetCreateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     def post(self):
-        # The role of the current user in the ta table must be admin, owner, or editor
-        current_user, current_tenant_id = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()
+        _, current_tenant_id = current_account_with_tenant()

         parser = (
             reqparse.RequestParser()
             .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")

@@ -1,7 +1,7 @@
 from flask_restx import (  # type: ignore
     Resource,  # type: ignore
-    reqparse,
 )
+from pydantic import BaseModel
 from werkzeug.exceptions import Forbidden

 from controllers.console import api, console_ns
@@ -12,17 +12,21 @@ from models import Account
 from models.dataset import Pipeline
 from services.rag_pipeline.rag_pipeline import RagPipelineService

-parser = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-    .add_argument("datasource_type", type=str, required=True, location="json")
-    .add_argument("credential_id", type=str, required=False, location="json")
-)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class Parser(BaseModel):
+    inputs: dict
+    datasource_type: str
+    credential_id: str | None = None
+
+
+console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
 class DataSourceContentPreviewApi(Resource):
-    @api.expect(parser)
+    @api.expect(console_ns.models[Parser.__name__], validate=True)
     @setup_required
     @login_required
     @account_initialization_required
@@ -34,15 +38,10 @@ class DataSourceContentPreviewApi(Resource):
         if not isinstance(current_user, Account):
             raise Forbidden()

-        args = parser.parse_args()
-
-        inputs = args.get("inputs")
-        if inputs is None:
-            raise ValueError("missing inputs")
-        datasource_type = args.get("datasource_type")
-        if datasource_type is None:
-            raise ValueError("missing datasource_type")
+        args = Parser.model_validate(api.payload)

+        inputs = args.inputs
+        datasource_type = args.datasource_type
         rag_pipeline_service = RagPipelineService()
         preview_content = rag_pipeline_service.run_datasource_node_preview(
             pipeline=pipeline,
@@ -51,6 +50,6 @@ class DataSourceContentPreviewApi(Resource):
             account=current_user,
             datasource_type=datasource_type,
             is_published=True,
-            credential_id=args.get("credential_id"),
+            credential_id=args.credential_id,
         )
         return preview_content, 200
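
Note: replacing the reqparse parser with a pydantic `BaseModel` collapses the per-field None checks into one declaration: `schema_model` registers the JSON schema with the namespace so Swagger can render it, and `model_validate` raises `ValidationError` on bad payloads. A standalone illustration of the validation half:

```python
from pydantic import BaseModel, ValidationError


class Parser(BaseModel):
    inputs: dict
    datasource_type: str
    credential_id: str | None = None


args = Parser.model_validate({"inputs": {"q": 1}, "datasource_type": "online_document"})
print(args.credential_id)  # None -- the optional field is defaulted

try:
    Parser.model_validate({"inputs": "not-a-dict"})
except ValidationError as e:
    # wrong type for inputs plus missing datasource_type
    print(e.error_count(), "validation errors")
```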

@@ -1,11 +1,11 @@
 from flask_restx import Resource, marshal_with, reqparse  # type: ignore
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden

 from controllers.console import console_ns
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import (
     account_initialization_required,
+    edit_permission_required,
     setup_required,
 )
 from extensions.ext_database import db
@@ -21,12 +21,11 @@ class RagPipelineImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @marshal_with(pipeline_import_fields)
     def post(self):
-        # Check user role first
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         parser = (
             reqparse.RequestParser()
@@ -71,12 +70,10 @@ class RagPipelineImportConfirmApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @marshal_with(pipeline_import_fields)
     def post(self, import_id):
         current_user, _ = current_account_with_tenant()
-        # Check user role first
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         # Create service with session
         with Session(db.engine) as session:
@@ -98,12 +95,9 @@ class RagPipelineImportCheckDependenciesApi(Resource):
     @login_required
     @get_rag_pipeline
     @account_initialization_required
+    @edit_permission_required
     @marshal_with(pipeline_import_check_dependencies_fields)
     def get(self, pipeline: Pipeline):
-        current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         with Session(db.engine) as session:
             import_service = RagPipelineDslService(session)
             result = import_service.check_dependencies(pipeline=pipeline)
@@ -117,12 +111,9 @@ class RagPipelineExportApi(Resource):
     @login_required
     @get_rag_pipeline
     @account_initialization_required
+    @edit_permission_required
     def get(self, pipeline: Pipeline):
-        current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()
-
         # Add include_secret params
         parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
         args = parser.parse_args()

@@ -191,6 +191,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline, node_id: str):
         """
@@ -198,8 +199,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_run.parse_args()
@@ -235,6 +234,7 @@ class DraftRagPipelineRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline):
         """
@@ -242,8 +242,6 @@ class DraftRagPipelineRunApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_draft_run.parse_args()
@@ -279,6 +277,7 @@ class PublishedRagPipelineRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline):
         """
@@ -286,8 +285,6 @@ class PublishedRagPipelineRunApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_published_run.parse_args()
@@ -404,6 +401,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline, node_id: str):
         """
@@ -411,8 +409,6 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_rag_run.parse_args()
@@ -444,6 +440,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
     @api.expect(parser_rag_run)
     @setup_required
     @login_required
+    @edit_permission_required
     @account_initialization_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline, node_id: str):
@@ -452,8 +449,6 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_rag_run.parse_args()
@@ -490,6 +485,7 @@ class RagPipelineDraftNodeRunApi(Resource):
     @api.expect(parser_run_api)
     @setup_required
     @login_required
+    @edit_permission_required
     @account_initialization_required
     @get_rag_pipeline
     @marshal_with(workflow_run_node_execution_fields)
@@ -499,8 +495,6 @@ class RagPipelineDraftNodeRunApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_run_api.parse_args()
@@ -523,6 +517,7 @@ class RagPipelineDraftNodeRunApi(Resource):
 class RagPipelineTaskStopApi(Resource):
     @setup_required
     @login_required
+    @edit_permission_required
     @account_initialization_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline, task_id: str):
@@ -531,8 +526,6 @@ class RagPipelineTaskStopApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@@ -544,6 +537,7 @@ class PublishedRagPipelineApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     @marshal_with(workflow_fields)
     def get(self, pipeline: Pipeline):
@@ -551,9 +545,6 @@ class PublishedRagPipelineApi(Resource):
         Get published pipeline
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()
         if not pipeline.is_published:
             return None
         # fetch published workflow by pipeline
@@ -566,6 +557,7 @@ class PublishedRagPipelineApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def post(self, pipeline: Pipeline):
         """
@@ -573,9 +565,6 @@ class PublishedRagPipelineApi(Resource):
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         rag_pipeline_service = RagPipelineService()
         with Session(db.engine) as session:
             pipeline = session.merge(pipeline)
@@ -602,16 +591,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def get(self, pipeline: Pipeline):
         """
         Get default block config
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         # Get default block configs
         rag_pipeline_service = RagPipelineService()
         return rag_pipeline_service.get_default_block_configs()
@@ -626,16 +611,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     def get(self, pipeline: Pipeline, block_type: str):
         """
         Get default block config
         """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_default.parse_args()

         q = args.get("q")
@@ -667,6 +648,7 @@ class PublishedAllRagPipelineApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     @marshal_with(workflow_pagination_fields)
     def get(self, pipeline: Pipeline):
@@ -674,8 +656,6 @@ class PublishedAllRagPipelineApi(Resource):
         Get published workflows
         """
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_wf.parse_args()
         page = args["page"]
@@ -720,6 +700,7 @@ class RagPipelineByIdApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     @get_rag_pipeline
     @marshal_with(workflow_fields)
     def patch(self, pipeline: Pipeline, workflow_id: str):
@@ -728,8 +709,6 @@ class RagPipelineByIdApi(Resource):
         """
         # Check permission
         current_user, _ = current_account_with_tenant()
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         args = parser_wf_id.parse_args()

@@ -3,7 +3,7 @@ from flask_restx import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden

 from controllers.console import api, console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from fields.tag_fields import dataset_tag_fields
 from libs.login import current_account_with_tenant, login_required
 from models.model import Tag
@@ -91,12 +91,9 @@ class TagUpdateDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @edit_permission_required
     def delete(self, tag_id):
-        current_user, _ = current_account_with_tenant()
         tag_id = str(tag_id)
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.has_edit_permission:
-            raise Forbidden()

         TagService.delete_tag(tag_id)
@@ -1,8 +1,7 @@
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required
@@ -31,11 +30,10 @@ class EndpointCreateApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()

parser = (
reqparse.RequestParser()
@@ -168,6 +166,7 @@ class EndpointDeleteApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@@ -175,9 +174,6 @@ class EndpointDeleteApi(Resource):
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()

if not user.is_admin_or_owner:
raise Forbidden()

endpoint_id = args["endpoint_id"]

return {
@@ -207,6 +203,7 @@ class EndpointUpdateApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@@ -223,9 +220,6 @@ class EndpointUpdateApi(Resource):
settings = args["settings"]
name = args["name"]

if not user.is_admin_or_owner:
raise Forbidden()

return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
@@ -252,6 +246,7 @@ class EndpointEnableApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@@ -261,9 +256,6 @@ class EndpointEnableApi(Resource):

endpoint_id = args["endpoint_id"]

if not user.is_admin_or_owner:
raise Forbidden()

return {
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@@ -284,6 +276,7 @@ class EndpointDisableApi(Resource):
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@@ -293,9 +286,6 @@ class EndpointDisableApi(Resource):

endpoint_id = args["endpoint_id"]

if not user.is_admin_or_owner:
raise Forbidden()

return {
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

@@ -2,10 +2,9 @@ import io

from flask import send_file
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -85,12 +84,10 @@ class ModelProviderCredentialApi(Resource):
@api.expect(parser_post_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()

_, current_tenant_id = current_account_with_tenant()
args = parser_post_cred.parse_args()

model_provider_service = ModelProviderService()
@@ -110,11 +107,10 @@ class ModelProviderCredentialApi(Resource):
@api.expect(parser_put_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()

args = parser_put_cred.parse_args()

@@ -136,12 +132,10 @@ class ModelProviderCredentialApi(Resource):
@api.expect(parser_delete_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()

_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()

model_provider_service = ModelProviderService()
@@ -162,11 +156,10 @@ class ModelProviderCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
args = parser_switch.parse_args()

service = ModelProviderService()
@@ -250,11 +243,10 @@ class PreferredProviderTypeUpdateApi(Resource):
@api.expect(parser_preferred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()

tenant_id = current_tenant_id

@@ -1,10 +1,9 @@
import logging

from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -50,12 +49,10 @@ class DefaultModelApi(Resource):
@api.expect(parser_post_default)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()

args = parser_post_default.parse_args()
model_provider_service = ModelProviderService()
@@ -133,13 +130,11 @@ class ModelProviderModelApi(Resource):
@api.expect(parser_post_models)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()
args = parser_post_models.parse_args()

if args.get("config_from", "") == "custom-model":
@@ -181,12 +176,10 @@ class ModelProviderModelApi(Resource):
@api.expect(parser_delete_models)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()

args = parser_delete_models.parse_args()

@@ -314,12 +307,10 @@ class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_post_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()

args = parser_post_cred.parse_args()

@@ -348,13 +339,10 @@ class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_put_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()

_, current_tenant_id = current_account_with_tenant()
args = parser_put_cred.parse_args()

model_provider_service = ModelProviderService()
@@ -377,12 +365,10 @@ class ModelProviderModelCredentialApi(Resource):
@api.expect(parser_delete_cred)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
args = parser_delete_cred.parse_args()

model_provider_service = ModelProviderService()
@@ -417,12 +403,11 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@api.expect(parser_switch)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
_, current_tenant_id = current_account_with_tenant()

if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_switch.parse_args()

service = ModelProviderService()

@@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import current_account_with_tenant, login_required
@@ -132,9 +132,11 @@ class PluginAssetApi(Resource):
@login_required
@account_initialization_required
def get(self):
req = reqparse.RequestParser()
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
req.add_argument("file_name", type=str, required=True, location="args")
req = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("file_name", type=str, required=True, location="args")
)
args = req.parse_args()

_, tenant_id = current_account_with_tenant()
@@ -619,13 +621,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@api.expect(parser_dynamic)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()

user_id = current_user.id

args = parser_dynamic.parse_args()
@@ -770,9 +769,11 @@ class PluginReadmeApi(Resource):
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=False, location="args")
parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
.add_argument("language", type=str, required=False, location="args")
)
args = parser.parse_args()
return jsonable_encoder(
{

@@ -14,6 +14,7 @@ from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@@ -115,11 +116,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@api.expect(parser_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()

args = parser_delete.parse_args()

@@ -177,13 +177,10 @@ class ToolBuiltinProviderUpdateApi(Resource):
@api.expect(parser_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_update.parse_args()
@@ -242,13 +239,11 @@ class ToolApiProviderAddApi(Resource):
@api.expect(parser_api_add)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_api_add.parse_args()
@@ -336,13 +331,11 @@ class ToolApiProviderUpdateApi(Resource):
@api.expect(parser_api_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_api_update.parse_args()
@@ -372,13 +365,11 @@ class ToolApiProviderDeleteApi(Resource):
@api.expect(parser_api_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_api_delete.parse_args()
@@ -496,13 +487,11 @@ class ToolWorkflowProviderCreateApi(Resource):
@api.expect(parser_create)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_create.parse_args()
@@ -539,13 +528,10 @@ class ToolWorkflowProviderUpdateApi(Resource):
@api.expect(parser_workflow_update)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_workflow_update.parse_args()
@@ -577,13 +563,11 @@ class ToolWorkflowProviderDeleteApi(Resource):
@api.expect(parser_workflow_delete)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

user_id = user.id

args = parser_workflow_delete.parse_args()
@@ -734,18 +718,15 @@ class ToolLabelsApi(Resource):
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name

# todo check permission
user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()

oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
@@ -856,14 +837,12 @@ class ToolOAuthCustomClient(Resource):
@api.expect(parser_custom)
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
args = parser_custom.parse_args()

user, tenant_id = current_account_with_tenant()

if not user.is_admin_or_owner:
raise Forbidden()
_, tenant_id = current_account_with_tenant()

return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=tenant_id,
@@ -1086,7 +1065,13 @@ class ToolMCPAuthApi(Resource):
return {"result": "success"}
except MCPAuthError as e:
try:
auth_result = auth(provider_entity, args.get("authorization_code"))
# Pass the extracted OAuth metadata hints to auth()
auth_result = auth(
provider_entity,
args.get("authorization_code"),
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
@@ -1096,7 +1081,7 @@ class ToolMCPAuthApi(Resource):
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except MCPError as e:
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)

@@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest, Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -67,14 +67,12 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

try:
return jsonable_encoder(
@@ -92,17 +90,16 @@ class TriggerSubscriptionListApi(Resource):
class TriggerSubscriptionBuilderCreateApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
args = parser.parse_args()

try:
@@ -133,18 +130,17 @@ class TriggerSubscriptionBuilderGetApi(Resource):
class TriggerSubscriptionBuilderVerifyApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()

try:
@@ -173,15 +169,17 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None

parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
return jsonable_encoder(
@@ -223,24 +221,23 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
class TriggerSubscriptionBuilderBuildApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
args = parser.parse_args()
try:
# Use atomic update_and_build to prevent race conditions
@@ -264,14 +261,12 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
class TriggerSubscriptionDeleteApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, subscription_id: str):
"""Delete a subscription instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

try:
with Session(db.engine) as session:
@@ -446,14 +441,12 @@ class TriggerOAuthCallbackApi(Resource):
class TriggerOAuthClientManageApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

try:
provider_id = TriggerProviderID(provider)
@@ -493,18 +486,18 @@ class TriggerOAuthClientManageApi(Resource):

@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
parser = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
args = parser.parse_args()

try:
@@ -524,14 +517,12 @@ class TriggerOAuthClientManageApi(Resource):

@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()

try:
provider_id = TriggerProviderID(provider)

@@ -128,7 +128,7 @@ class TenantApi(Resource):
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
def get(self):
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")

@@ -315,3 +315,19 @@ def edit_permission_required(f: Callable[P, R]):
return f(*args, **kwargs)

return decorated_function


def is_admin_or_owner_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden

from libs.login import current_user
from models import Account

user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
raise Forbidden()
return f(*args, **kwargs)

return decorated_function

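For orientation, a minimal sketch of how a console endpoint reads once this decorator is in place (the resource name below is hypothetical, not part of this commit; the decorator order mirrors the hunks above). The handler body no longer needs its own is_admin_or_owner check, because the decorator raises Forbidden() before the handler runs:

class ExampleAdminOnlyApi(Resource):  # hypothetical, for illustration only
    @setup_required
    @login_required
    @is_admin_or_owner_required
    @account_initialization_required
    def post(self):
        # No inline `if not user.is_admin_or_owner: raise Forbidden()` needed here.
        _, tenant_id = current_account_with_tenant()
        return {"result": "success"}
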
@@ -3,14 +3,12 @@ from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.api import HTTPStatus
from werkzeug.exceptions import Forbidden

from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models import Account
from models.model import App
from services.annotation_service import AppAnnotationService

@@ -161,14 +159,10 @@ class AnnotationUpdateDeleteApi(Resource):
}
)
@validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id):
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()

annotation_id = str(annotation_id)
args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@@ -185,13 +179,8 @@ class AnnotationUpdateDeleteApi(Resource):
}
)
@validate_app_token
def delete(self, app_model: App, annotation_id):
@edit_permission_required
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation."""
assert isinstance(current_user, Account)

if not current_user.has_edit_permission:
raise Forbidden()

annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204

@@ -5,6 +5,7 @@ from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
@@ -619,11 +620,9 @@ class DatasetTagsApi(DatasetApiResource):
}
)
@validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"])

@@ -1,7 +1,10 @@
import json
from typing import Self
from uuid import UUID

from flask import request
from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound

@@ -31,7 +34,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService

# Define parsers for document operations
@@ -51,15 +54,26 @@ document_text_create_parser = (
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)

document_text_update_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("text", type=str, required=False, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class DocumentTextUpdate(BaseModel):
name: str | None = None
text: str | None = None
process_rule: ProcessRule | None = None
doc_form: str = "text_model"
doc_language: str = "English"
retrieval_model: RetrievalModel | None = None

@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:
raise ValueError("name is required when text is provided")
return self


for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore


@service_api_ns.route(
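A hedged sketch of what the new model's cross-field validator enforces (payload values below are made up): providing text without a name fails validation, which Pydantic surfaces as a ValidationError wrapping the ValueError raised in check_text_and_name:

from pydantic import ValidationError

DocumentTextUpdate.model_validate({"name": "notes.txt", "text": "hello"})  # accepted
try:
    DocumentTextUpdate.model_validate({"text": "hello"})  # text without name
except ValidationError:
    pass  # carries "name is required when text is provided"
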
@@ -160,7 +174,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""

@service_api_ns.expect(document_text_update_parser)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -173,12 +187,10 @@ class DocumentUpdateByTextApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
args = document_text_update_parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -198,11 +210,9 @@ class DocumentUpdateByTextApi(DatasetApiResource):
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique

if args["text"]:
if args.get("text"):
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
@@ -456,12 +466,16 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
status = request.args.get("status", default=None, type=str)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)

if status:
query = DocumentService.apply_display_status_filter(query, status)

if search:
search = f"%{search}%"
query = query.where(Document.name.like(search))

@@ -81,6 +81,7 @@ class LoginStatusApi(Resource):
)
def get(self):
app_code = request.args.get("app_code")
user_id = request.args.get("user_id")
token = extract_webapp_access_token(request)
if not app_code:
return {
@@ -103,7 +104,7 @@ class LoginStatusApi(Resource):
user_logged_in = False

try:
_ = decode_jwt_token(app_code=app_code)
_ = decode_jwt_token(app_code=app_code, user_id=user_id)
app_logged_in = True
except Exception:
app_logged_in = False

@@ -38,7 +38,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator


def decode_jwt_token(app_code: str | None = None):
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
system_features = FeatureService.get_system_features()
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
@@ -63,6 +63,10 @@ def decode_jwt_token(app_code: str | None = None):
if not end_user:
raise NotFound()

# Validate user_id against end_user's session_id if provided
if user_id is not None and end_user.session_id != user_id:
raise Unauthorized("Authentication has expired.")

# for enterprise webapp auth
app_web_auth_enabled = False
webapp_settings = None

@@ -145,7 +145,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# for trigger debug run, not prepare user inputs
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs(
user_inputs=inputs,

@@ -644,14 +644,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if not workflow_run_id:
return

workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
workflow_app_log = WorkflowAppLog(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
workflow_id=self._workflow.id,
workflow_run_id=workflow_run_id,
created_from=created_from.value,
created_by_role=self._created_by_role,
created_by=self._user_id,
)

session.add(workflow_app_log)
session.commit()

@@ -1,14 +1,10 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import Any

from pydantic import BaseModel, Field

# Import InvokeFrom locally to avoid circular import
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom

if TYPE_CHECKING:
from core.app.entities.app_invoke_entities import InvokeFrom


class DatasourceRuntime(BaseModel):
"""
@@ -17,7 +13,7 @@ class DatasourceRuntime(BaseModel):

tenant_id: str
datasource_id: str | None = None
invoke_from: Optional["InvokeFrom"] = None
invoke_from: InvokeFrom | None = None
datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@@ -6,7 +6,8 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse

from httpx import ConnectError, HTTPStatusError, RequestError
import httpx
from httpx import RequestError
from pydantic import ValidationError

from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
@@ -20,6 +21,7 @@ from core.mcp.types import (
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
from extensions.ext_redis import redis_client

@@ -39,6 +41,131 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge


def build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url: str | None, server_url: str
) -> list[str]:
"""
Build a list of URLs to try for Protected Resource Metadata discovery.

Per SEP-985, supports fallback when discovery fails at one URL.
"""
urls = []

# First priority: URL from WWW-Authenticate header
if www_auth_resource_metadata_url:
urls.append(www_auth_resource_metadata_url)

# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)

return urls


def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Build a list of URLs to try for OAuth Authorization Server Metadata discovery.

Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.

Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}

Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
"""
urls = []
base_url = auth_server_url or server_url

parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash

# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))

# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))

return urls


def discover_protected_resource_metadata(
prm_url: str | None, server_url: str, protocol_version: str | None = None
) -> ProtectedResourceMetadata | None:
"""Discover OAuth 2.0 Protected Resource Metadata (RFC 9470)."""
urls = build_protected_resource_metadata_discovery_urls(prm_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}

for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL

return None


def discover_oauth_authorization_server_metadata(
auth_server_url: str | None, server_url: str, protocol_version: str | None = None
) -> OAuthMetadata | None:
"""Discover OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
urls = build_oauth_authorization_server_metadata_discovery_urls(auth_server_url, server_url)
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}

for url in urls:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 200:
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
continue # Try next URL

return None


def get_effective_scope(
scope_from_www_auth: str | None,
prm: ProtectedResourceMetadata | None,
asm: OAuthMetadata | None,
client_scope: str | None,
) -> str | None:
"""
Determine effective scope using priority-based selection strategy.

Priority order:
1. WWW-Authenticate header scope (server explicit requirement)
2. Protected Resource Metadata scopes
3. OAuth Authorization Server Metadata scopes
4. Client configured scope
"""
if scope_from_www_auth:
return scope_from_www_auth

if prm and prm.scopes_supported:
return " ".join(prm.scopes_supported)

if asm and asm.scopes_supported:
return " ".join(asm.scopes_supported)

return client_scope


def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
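A small sketch of what the authorization-server URL builder yields for an issuer with a path component (hosts below are illustrative); the second entry follows the RFC 8414 rule quoted in the docstring above:

urls = build_oauth_authorization_server_metadata_discovery_urls(
    "https://example.com/oauth",  # issuer taken from Protected Resource Metadata
    "https://mcp.example.com/sse",  # MCP server URL, only used when no issuer is known
)
assert urls == [
    "https://example.com/.well-known/openid-configuration",
    "https://example.com/.well-known/oauth-authorization-server/oauth",
]
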
@@ -121,42 +248,36 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
return False, ""


def discover_oauth_metadata(server_url: str, protocol_version: str | None = None) -> OAuthMetadata | None:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
# The oauth_discovery_url is the authorization server base URL
# Try OpenID Connect discovery first (more common), then OAuth 2.0
urls_to_try = [
urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
]
else:
urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]
def discover_oauth_metadata(
server_url: str,
resource_metadata_url: str | None = None,
scope_hint: str | None = None,
protocol_version: str | None = None,
) -> tuple[OAuthMetadata | None, ProtectedResourceMetadata | None, str | None]:
"""
Discover OAuth metadata using RFC 8414/9470 standards.

headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
Args:
server_url: The MCP server URL
resource_metadata_url: Protected Resource Metadata URL from WWW-Authenticate header
scope_hint: Scope hint from WWW-Authenticate header
protocol_version: MCP protocol version

for url in urls_to_try:
try:
response = ssrf_proxy.get(url, headers=headers)
if response.status_code == 404:
continue
if not response.is_success:
response.raise_for_status()
return OAuthMetadata.model_validate(response.json())
except (RequestError, HTTPStatusError) as e:
if isinstance(e, ConnectError):
response = ssrf_proxy.get(url)
if response.status_code == 404:
continue # Try next URL
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
# For other errors, try next URL
continue
Returns:
(oauth_metadata, protected_resource_metadata, scope_hint)
"""
# Discover Protected Resource Metadata
prm = discover_protected_resource_metadata(resource_metadata_url, server_url, protocol_version)

return None # No metadata found
# Get authorization server URL from PRM or use server URL
auth_server_url = None
if prm and prm.authorization_servers:
auth_server_url = prm.authorization_servers[0]

# Discover OAuth Authorization Server Metadata
asm = discover_oauth_authorization_server_metadata(auth_server_url, server_url, protocol_version)

return asm, prm, scope_hint


def start_authorization(
@@ -166,6 +287,7 @@ def start_authorization(
redirect_url: str,
provider_id: str,
tenant_id: str,
scope: str | None = None,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
@@ -175,13 +297,6 @@ def start_authorization(
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else:
authorization_url = urljoin(server_url, "/authorize")

@@ -210,10 +325,49 @@ def start_authorization(
"state": state_key,
}

# Add scope if provided
if scope:
params["scope"] = scope

authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier


def _parse_token_response(response: httpx.Response) -> OAuthTokens:
"""
Parse OAuth token response supporting both JSON and form-urlencoded formats.

Per RFC 6749 Section 5.1, the standard format is JSON.
However, some legacy OAuth providers (e.g., early GitHub OAuth Apps) return
application/x-www-form-urlencoded format for backwards compatibility.

Args:
response: The HTTP response from token endpoint

Returns:
Parsed OAuth tokens

Raises:
ValueError: If response cannot be parsed
"""
content_type = response.headers.get("content-type", "").lower()

if "application/json" in content_type:
# Standard OAuth 2.0 JSON response (RFC 6749)
return OAuthTokens.model_validate(response.json())
elif "application/x-www-form-urlencoded" in content_type:
# Legacy form-urlencoded response (non-standard but used by some providers)
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)
else:
# No content-type or unknown - try JSON first, fallback to form-urlencoded
try:
return OAuthTokens.model_validate(response.json())
except (ValidationError, json.JSONDecodeError):
token_data = dict(urllib.parse.parse_qsl(response.text))
return OAuthTokens.model_validate(token_data)


def exchange_authorization(
server_url: str,
metadata: OAuthMetadata | None,
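For the legacy branch of _parse_token_response, a brief illustration of how a form-urlencoded token body decodes before OAuthTokens.model_validate coerces it (token values made up):

import urllib.parse

body = "access_token=abc123&token_type=bearer&expires_in=3600"
token_data = dict(urllib.parse.parse_qsl(body))
assert token_data == {"access_token": "abc123", "token_type": "bearer", "expires_in": "3600"}
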
@ -246,7 +400,7 @@ def exchange_authorization(
|
|||
response = ssrf_proxy.post(token_url, data=params)
|
||||
if not response.is_success:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
|
|
@ -279,7 +433,7 @@ def refresh_authorization(
|
|||
raise MCPRefreshTokenError(e) from e
|
||||
if not response.is_success:
|
||||
raise MCPRefreshTokenError(response.text)
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def client_credentials_flow(
|
||||
|
|
@ -322,7 +476,7 @@ def client_credentials_flow(
|
|||
f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
return _parse_token_response(response)
|
||||
|
||||
|
||||
def register_client(
|
||||
|
|
@ -352,6 +506,8 @@ def auth(
|
|||
provider: MCPProviderEntity,
|
||||
authorization_code: str | None = None,
|
||||
state_param: str | None = None,
|
||||
resource_metadata_url: str | None = None,
|
||||
scope_hint: str | None = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Orchestrates the full auth flow with a server using secure Redis state storage.
|
||||
|
|
@ -363,18 +519,26 @@ def auth(
        provider: The MCP provider entity
        authorization_code: Optional authorization code from OAuth callback
        state_param: Optional state parameter from OAuth callback
        resource_metadata_url: Optional Protected Resource Metadata URL from WWW-Authenticate
        scope_hint: Optional scope hint from WWW-Authenticate header

    Returns:
        AuthResult containing actions to be performed and response data
    """
    actions: list[AuthAction] = []
    server_url = provider.decrypt_server_url()
    server_metadata = discover_oauth_metadata(server_url)

    # Discover OAuth metadata using RFC 8414/9470 standards
    server_metadata, prm, scope_from_www_auth = discover_oauth_metadata(
        server_url, resource_metadata_url, scope_hint, LATEST_PROTOCOL_VERSION
    )

    client_metadata = provider.client_metadata
    provider_id = provider.id
    tenant_id = provider.tenant_id
    client_information = provider.retrieve_client_information()
    redirect_url = provider.redirect_url
    credentials = provider.decrypt_credentials()

    # Determine grant type based on server metadata
    if not server_metadata:
@ -392,8 +556,8 @@ def auth(
    else:
        effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value

    # Get stored credentials
    credentials = provider.decrypt_credentials()
    # Determine effective scope using priority-based strategy
    effective_scope = get_effective_scope(scope_from_www_auth, prm, server_metadata, credentials.get("scope"))

    if not client_information:
        if authorization_code is not None:
@ -425,12 +589,11 @@ def auth(
    if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
        # Direct token request without user interaction
        try:
            scope = credentials.get("scope")
            tokens = client_credentials_flow(
                server_url,
                server_metadata,
                client_information,
                scope,
                effective_scope,
            )

            # Return action to save tokens and grant type
@ -526,6 +689,7 @@ def auth(
        redirect_url,
        provider_id,
        tenant_id,
        effective_scope,
    )

    # Return action to save code verifier
@ -90,7 +90,13 @@ class MCPClientWithAuthRetry(MCPClient):
            mcp_service = MCPToolManageService(session=session)

            # Perform authentication using the service's auth method
            mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)
            # Extract OAuth metadata hints from the error
            mcp_service.auth_with_actions(
                self.provider_entity,
                self.authorization_code,
                resource_metadata_url=error.resource_metadata_url,
                scope_hint=error.scope_hint,
            )

            # Retrieve new tokens
            self.provider_entity = mcp_service.get_provider_entity(
@ -290,7 +290,7 @@ def sse_client(

    except httpx.HTTPStatusError as exc:
        if exc.response.status_code == 401:
            raise MCPAuthError()
            raise MCPAuthError(response=exc.response)
        raise MCPConnectionError()
    except Exception:
        logger.exception("Error connecting to SSE endpoint")
@ -1,3 +1,10 @@
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import httpx


class MCPError(Exception):
    pass

@ -7,7 +14,49 @@ class MCPConnectionError(MCPError):


class MCPAuthError(MCPConnectionError):
    pass
    def __init__(
        self,
        message: str | None = None,
        response: "httpx.Response | None" = None,
        www_authenticate_header: str | None = None,
    ):
        """
        MCP Authentication Error.

        Args:
            message: Error message
            response: HTTP response object (will extract WWW-Authenticate header if provided)
            www_authenticate_header: Pre-extracted WWW-Authenticate header value
        """
        super().__init__(message or "Authentication failed")

        # Extract OAuth metadata hints from WWW-Authenticate header
        if response is not None:
            www_authenticate_header = response.headers.get("WWW-Authenticate")

        self.resource_metadata_url: str | None = None
        self.scope_hint: str | None = None

        if www_authenticate_header:
            self.resource_metadata_url = self._extract_field(www_authenticate_header, "resource_metadata")
            self.scope_hint = self._extract_field(www_authenticate_header, "scope")

    @staticmethod
    def _extract_field(www_auth: str, field_name: str) -> str | None:
        """Extract a specific field from the WWW-Authenticate header."""
        # Pattern to match field="value" or field=value
        pattern = rf'{field_name}="([^"]*)"'
        match = re.search(pattern, www_auth)
        if match:
            return match.group(1)

        # Try without quotes
        pattern = rf"{field_name}=([^\s,]+)"
        match = re.search(pattern, www_auth)
        if match:
            return match.group(1)

        return None


class MCPRefreshTokenError(MCPError):
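A minimal sketch of the header parsing above (outside this changeset); the header value is illustrative:

    header = (
        'Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", '
        'scope="files:read files:write"'
    )
    err = MCPAuthError(www_authenticate_header=header)
    assert err.resource_metadata_url == "https://mcp.example.com/.well-known/oauth-protected-resource"
    assert err.scope_hint == "files:read files:write"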
@ -149,7 +149,7 @@ class BaseSession(
    messages when entered.
    """

    _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
    _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError]]
    _request_id: int
    _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
    _receive_request_type: type[ReceiveRequestT]
@ -230,7 +230,7 @@ class BaseSession(
        request_id = self._request_id
        self._request_id = request_id + 1

        response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
        response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | HTTPStatusError] = queue.Queue()
        self._response_streams[request_id] = response_queue

        try:
@ -261,11 +261,17 @@ class BaseSession(
                    message="No response received",
                )
            )
        elif isinstance(response_or_error, HTTPStatusError):
            # HTTPStatusError from streamable_client with preserved response object
            if response_or_error.response.status_code == 401:
                raise MCPAuthError(response=response_or_error.response)
            else:
                raise MCPConnectionError(
                    ErrorData(code=response_or_error.response.status_code, message=str(response_or_error))
                )
        elif isinstance(response_or_error, JSONRPCError):
            if response_or_error.error.code == 401:
                raise MCPAuthError(
                    ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
                )
                raise MCPAuthError(message=response_or_error.error.message)
            else:
                raise MCPConnectionError(
                    ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
@ -327,13 +333,17 @@ class BaseSession(
        if isinstance(message, HTTPStatusError):
            response_queue = self._response_streams.get(self._request_id - 1)
            if response_queue is not None:
                response_queue.put(
                    JSONRPCError(
                        jsonrpc="2.0",
                        id=self._request_id - 1,
                        error=ErrorData(code=message.response.status_code, message=message.args[0]),
                # For 401 errors, pass the HTTPStatusError directly to preserve response object
                if message.response.status_code == 401:
                    response_queue.put(message)
                else:
                    response_queue.put(
                        JSONRPCError(
                            jsonrpc="2.0",
                            id=self._request_id - 1,
                            error=ErrorData(code=message.response.status_code, message=message.args[0]),
                        )
                    )
                )
            else:
                self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
        elif isinstance(message, Exception):
@ -23,7 +23,7 @@ for reference.
not separate types in the schema.
"""
# Client support both version, not support 2025-06-18 yet.
LATEST_PROTOCOL_VERSION = "2025-03-26"
LATEST_PROTOCOL_VERSION = "2025-06-18"
# Server support 2024-11-05 to allow claude to use.
SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"

@ -1331,3 +1331,13 @@ class OAuthMetadata(BaseModel):
    response_types_supported: list[str]
    grant_types_supported: list[str] | None = None
    code_challenge_methods_supported: list[str] | None = None
    scopes_supported: list[str] | None = None


class ProtectedResourceMetadata(BaseModel):
    """OAuth 2.0 Protected Resource Metadata (RFC 9470)."""

    resource: str | None = None
    authorization_servers: list[str]
    scopes_supported: list[str] | None = None
    bearer_methods_supported: list[str] | None = None
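A quick sketch of validating a PRM document with the new model (outside this changeset); the JSON shown is illustrative and feeds the discovery step shown earlier:

    prm = ProtectedResourceMetadata.model_validate(
        {
            "resource": "https://mcp.example.com",
            "authorization_servers": ["https://auth.example.com"],
            "scopes_supported": ["files:read"],
        }
    )
    auth_server_url = prm.authorization_servers[0]  # first advertised authorization server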
@ -1,12 +1,20 @@
import logging
import os
import uuid
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from typing import Any, cast

import wandb
import weave
from sqlalchemy.orm import sessionmaker
from weave.trace_server.trace_server_interface import (
    CallEndReq,
    CallStartReq,
    EndedCallSchemaForInsert,
    StartedCallSchemaForInsert,
    SummaryInsertMap,
    TraceStatus,
)

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig

@ -57,6 +65,7 @@ class WeaveDataTrace(BaseTraceInstance):
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
        self.calls: dict[str, Any] = {}
        self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"

    def get_project_url(
        self,
@ -424,6 +433,13 @@ class WeaveDataTrace(BaseTraceInstance):
            logger.debug("Weave API check failed: %s", str(e))
            raise ValueError(f"Weave API check failed: {str(e)}")

    def _normalize_time(self, dt: datetime | None) -> datetime:
        if dt is None:
            return datetime.now(UTC)
        if dt.tzinfo is None:
            return dt.replace(tzinfo=UTC)
        return dt

    def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
        inputs = run_data.inputs
        if inputs is None:
@ -437,19 +453,71 @@ class WeaveDataTrace(BaseTraceInstance):
        elif not isinstance(attributes, dict):
            attributes = {"attributes": str(attributes)}

        call = self.weave_client.create_call(
            op=run_data.op,
            inputs=inputs,
            attributes=attributes,
        start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
        started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
        trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
        if trace_id is None:
            trace_id = run_data.id

        call_start_req = CallStartReq(
            start=StartedCallSchemaForInsert(
                project_id=self.project_id,
                id=run_data.id,
                op_name=str(run_data.op),
                trace_id=trace_id,
                parent_id=parent_run_id,
                started_at=started_at,
                attributes=attributes,
                inputs=inputs,
                wb_user_id=None,
            )
        )
        self.calls[run_data.id] = call
        if parent_run_id:
            self.calls[run_data.id].parent_id = parent_run_id
        self.weave_client.server.call_start(call_start_req)
        self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}

    def finish_call(self, run_data: WeaveTraceModel):
        call = self.calls.get(run_data.id)
        if call:
            exception = Exception(run_data.exception) if run_data.exception else None
            self.weave_client.finish_call(call=call, output=run_data.outputs, exception=exception)
        else:
        call_meta = self.calls.get(run_data.id)
        if not call_meta:
            raise ValueError(f"Call with id {run_data.id} not found")

        attributes = run_data.attributes
        if attributes is None:
            attributes = {}
        elif not isinstance(attributes, dict):
            attributes = {"attributes": str(attributes)}

        start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
        end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
        started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
        ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
        elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
        if elapsed_ms < 0:
            elapsed_ms = 0

        status_counts = {
            TraceStatus.SUCCESS: 0,
            TraceStatus.ERROR: 0,
        }
        if run_data.exception:
            status_counts[TraceStatus.ERROR] = 1
        else:
            status_counts[TraceStatus.SUCCESS] = 1

        summary: dict[str, Any] = {
            "status_counts": status_counts,
            "weave": {"latency_ms": elapsed_ms},
        }

        exception_str = str(run_data.exception) if run_data.exception else None

        call_end_req = CallEndReq(
            end=EndedCallSchemaForInsert(
                project_id=self.project_id,
                id=run_data.id,
                ended_at=ended_at,
                exception=exception_str,
                output=run_data.outputs,
                summary=cast(SummaryInsertMap, summary),
            )
        )
        self.weave_client.server.call_end(call_end_req)
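A small sketch of the guarantee _normalize_time provides (outside this changeset), so the latency math in finish_call never mixes naive and aware datetimes; `tracer` stands in for an initialized WeaveDataTrace:

    from datetime import UTC, datetime

    naive = datetime(2025, 1, 1, 12, 0, 0)              # no tzinfo
    assert tracer._normalize_time(naive).tzinfo is UTC  # naive -> same wall time, tagged UTC
    assert tracer._normalize_time(None).tzinfo is UTC   # missing -> "now" in UTC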
@ -22,6 +22,18 @@ logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")

T = TypeVar("T", bound="MatrixoneVector")


def ensure_client(func: Callable[Concatenate[T, P], R]):
    @wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
        if self.client is None:
            self.client = self._get_client(None, False)
        return func(self, *args, **kwargs)

    return wrapper


class MatrixoneConfig(BaseModel):
    host: str = "localhost"
@ -206,19 +218,6 @@ class MatrixoneVector(BaseVector):
        self.client.delete()


T = TypeVar("T", bound=MatrixoneVector)


def ensure_client(func: Callable[Concatenate[T, P], R]):
    @wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
        if self.client is None:
            self.client = self._get_client(None, False)
        return func(self, *args, **kwargs)

    return wrapper


class MatrixoneVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
        if dataset.index_struct_dict:
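A minimal sketch of the lazy-init pattern the relocated decorator enables (outside this changeset); the class and method here are illustrative stand-ins, assuming ensure_client is imported from this module:

    class LazyClientHolder:
        """Illustrative stand-in for MatrixoneVector's client handling."""

        def __init__(self):
            self.client = None

        def _get_client(self, config, create_index):
            return object()  # stands in for the real Matrixone client

        @ensure_client
        def query(self):
            return self.client  # populated by the wrapper on first call

    holder = LazyClientHolder()
    assert holder.query() is not None  # client was created lazily

Moving the decorator above the class is what forces the string forward reference in the TypeVar bound ("MatrixoneVector" instead of the class object).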
@ -7,8 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Union, cast

from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, or_, select

from core.app.app_config.entities import (
    DatasetEntity,
@ -1023,60 +1022,55 @@ class DatasetRetrieval:
        self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list
    ):
        if value is None and condition not in ("empty", "not empty"):
            return
            return filters

        json_field = DatasetDocument.doc_metadata[metadata_name].as_string()

        key = f"{metadata_name}_{sequence}"
        key_value = f"{metadata_name}_{sequence}_value"
        match condition:
            case "contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
                filters.append(json_field.like(f"%{value}%"))

            case "not contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
                filters.append(json_field.notlike(f"%{value}%"))

            case "start with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"{value}%"}
                    )
                )
                filters.append(json_field.like(f"{value}%"))

            case "end with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}"}
                    )
                )
                filters.append(json_field.like(f"%{value}"))

            case "is" | "=":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
                    filters.append(json_field == value)
                elif isinstance(value, (int, float)):
                    filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() == value)

            case "is not" | "≠":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
                    filters.append(json_field != value)
                elif isinstance(value, (int, float)):
                    filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() != value)

            case "empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))

            case "not empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))

            case "before" | "<":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
                filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() < value)

            case "after" | ">":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
                filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() > value)

            case "≤" | "<=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
                filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() <= value)

            case "≥" | ">=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
                filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value)
            case _:
                pass

        return filters

    def _fetch_model_config(
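Why the rewrite matters, as a short sketch (outside this changeset): the .as_string()/.as_float() accessors let SQLAlchemy render the right JSON extraction per backend, where the old text("documents.doc_metadata ->> ...") was PostgreSQL-only; the field name is illustrative:

    json_field = DatasetDocument.doc_metadata["author"].as_string()
    stmt = select(DatasetDocument).where(json_field.like("%smith%"))
    # PostgreSQL renders a ->> extraction; MySQL renders JSON_UNQUOTE(JSON_EXTRACT(...)),
    # so the same filter now compiles on both backends.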
@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from yarl import URL

import contexts
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@ -32,7 +33,6 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
    from core.workflow.nodes.tool.entities import ToolEntity

from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
@ -63,7 +63,6 @@ from services.tools.tools_transform_service import ToolTransformService

if TYPE_CHECKING:
    from core.workflow.nodes.tool.entities import ToolEntity
    from core.workflow.runtime import VariablePool

logger = logging.getLogger(__name__)

@ -618,12 +617,28 @@ class ToolManager:
        """
        # according to multi credentials, select the one with is_default=True first, then created_at oldest
        # for compatibility with old version
        sql = """
        if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
            # PostgreSQL: Use DISTINCT ON
            sql = """
                SELECT DISTINCT ON (tenant_id, provider) id
                FROM tool_builtin_providers
                WHERE tenant_id = :tenant_id
                ORDER BY tenant_id, provider, is_default DESC, created_at DESC
            """
        else:
            # MySQL: Use window function to achieve same result
            sql = """
                SELECT id FROM (
                    SELECT id,
                        ROW_NUMBER() OVER (
                            PARTITION BY tenant_id, provider
                            ORDER BY is_default DESC, created_at DESC
                        ) as rn
                    FROM tool_builtin_providers
                    WHERE tenant_id = :tenant_id
                ) ranked WHERE rn = 1
            """

        with Session(db.engine, autoflush=False) as session:
            ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
            return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@ -1,9 +1,12 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from core.file.models import File

if TYPE_CHECKING:
    pass


class ArrayValidation(StrEnum):
    """Strategy for validating array elements.
@ -155,6 +158,17 @@ class SegmentType(StrEnum):
            return isinstance(value, File)
        elif self == SegmentType.NONE:
            return value is None
        elif self == SegmentType.GROUP:
            from .segment_group import SegmentGroup
            from .segments import Segment

            if isinstance(value, SegmentGroup):
                return all(isinstance(item, Segment) for item in value.value)

            if isinstance(value, list):
                return all(isinstance(item, Segment) for item in value)

            return False
        else:
            raise AssertionError("this statement should be unreachable.")
@ -6,8 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import sessionmaker

from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -597,79 +596,79 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
        if value is None and condition not in ("empty", "not empty"):
            return filters

        key = f"{metadata_name}_{sequence}"
        key_value = f"{metadata_name}_{sequence}_value"
        json_field = Document.doc_metadata[metadata_name].as_string()

        match condition:
            case "contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
                filters.append(json_field.like(f"%{value}%"))

            case "not contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
                filters.append(json_field.notlike(f"%{value}%"))

            case "start with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"{value}%"}
                    )
                )
                filters.append(json_field.like(f"{value}%"))

            case "end with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}"}
                    )
                )
                filters.append(json_field.like(f"%{value}"))
            case "in":
                if isinstance(value, str):
                    escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
                    escaped_value_str = ",".join(escaped_values)
                    value_list = [v.strip() for v in value.split(",") if v.strip()]
                elif isinstance(value, (list, tuple)):
                    value_list = [str(v) for v in value if v is not None]
                else:
                    escaped_value_str = str(value)
                    filters.append(
                        (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
                            **{key: metadata_name, key_value: escaped_value_str}
                        )
                    )
                    value_list = [str(value)] if value is not None else []

                if not value_list:
                    filters.append(literal(False))
                else:
                    filters.append(json_field.in_(value_list))

            case "not in":
                if isinstance(value, str):
                    escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
                    escaped_value_str = ",".join(escaped_values)
                    value_list = [v.strip() for v in value.split(",") if v.strip()]
                elif isinstance(value, (list, tuple)):
                    value_list = [str(v) for v in value if v is not None]
                else:
                    escaped_value_str = str(value)
                    filters.append(
                        (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
                            **{key: metadata_name, key_value: escaped_value_str}
                        )
                    )
            case "=" | "is":
                    value_list = [str(value)] if value is not None else []

                if not value_list:
                    filters.append(literal(True))
                else:
                    filters.append(json_field.notin_(value_list))

            case "is" | "=":
                if isinstance(value, str):
                    filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
                    filters.append(json_field == value)
                elif isinstance(value, (int, float)):
                    filters.append(Document.doc_metadata[metadata_name].as_float() == value)

            case "is not" | "≠":
                if isinstance(value, str):
                    filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
                    filters.append(json_field != value)
                elif isinstance(value, (int, float)):
                    filters.append(Document.doc_metadata[metadata_name].as_float() != value)

            case "empty":
                filters.append(Document.doc_metadata[metadata_name].is_(None))

            case "not empty":
                filters.append(Document.doc_metadata[metadata_name].isnot(None))

            case "before" | "<":
                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
                filters.append(Document.doc_metadata[metadata_name].as_float() < value)

            case "after" | ">":
                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
                filters.append(Document.doc_metadata[metadata_name].as_float() > value)

            case "≤" | "<=":
                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
                filters.append(Document.doc_metadata[metadata_name].as_float() <= value)

            case "≥" | ">=":
                filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
                filters.append(Document.doc_metadata[metadata_name].as_float() >= value)

            case _:
                pass

        return filters

    @classmethod
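A short sketch of the new portable "in" handling above (outside this changeset), which replaces the PostgreSQL-only string_to_array(...) = any(...) trick; values are illustrative:

    value = "alice, bob , "
    value_list = [v.strip() for v in value.split(",") if v.strip()]  # ["alice", "bob"]
    # An empty list becomes literal(False): the filter matches nothing instead of
    # emitting invalid SQL, and no manual quote-escaping is needed anymore.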
@ -3,7 +3,6 @@ from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol
@ -100,8 +99,8 @@ class ResponseStreamCoordinatorProtocol(Protocol):
class GraphProtocol(Protocol):
    """Structural interface required from graph instances attached to the runtime state."""

    nodes: TypingMapping[str, object]
    edges: TypingMapping[str, object]
    nodes: Mapping[str, object]
    edges: Mapping[str, object]
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@ -265,6 +265,45 @@ def _assert_not_empty(*, value: object) -> bool:
    return False


def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]:
    """
    Normalize value and expected to compatible numeric types for comparison.

    Args:
        value: The actual numeric value (int or float)
        expected: The expected value (int, float, or str)

    Returns:
        A tuple of (normalized_value, normalized_expected) with compatible types

    Raises:
        ValueError: If expected cannot be converted to a number
    """
    if not isinstance(expected, (int, float, str)):
        raise ValueError(f"Cannot convert {type(expected)} to number")

    # Convert expected to appropriate numeric type
    if isinstance(expected, str):
        # Try to convert to float first to handle decimal strings
        try:
            expected_float = float(expected)
        except ValueError as e:
            raise ValueError(f"Cannot convert '{expected}' to number") from e

        # If value is int and expected is a whole number, keep as int comparison
        if isinstance(value, int) and expected_float.is_integer():
            return value, int(expected_float)
        else:
            # Otherwise convert value to float for comparison
            return float(value) if isinstance(value, int) else value, expected_float
    elif isinstance(expected, float):
        # If expected is already float, convert int value to float
        return float(value) if isinstance(value, int) else value, expected
    else:
        # expected is int
        return value, expected


def _assert_equal(*, value: object, expected: object) -> bool:
    if value is None:
        return False
@ -324,18 +363,8 @@ def _assert_greater_than(*, value: object, expected: object) -> bool:
    if not isinstance(value, (int, float)):
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value <= expected:
        return False
    return True
    value, expected = _normalize_numeric_values(value, expected)
    return value > expected


def _assert_less_than(*, value: object, expected: object) -> bool:
@ -345,18 +374,8 @@ def _assert_less_than(*, value: object, expected: object) -> bool:
    if not isinstance(value, (int, float)):
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value >= expected:
        return False
    return True
    value, expected = _normalize_numeric_values(value, expected)
    return value < expected


def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
@ -366,18 +385,8 @@ def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
    if not isinstance(value, (int, float)):
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value < expected:
        return False
    return True
    value, expected = _normalize_numeric_values(value, expected)
    return value >= expected


def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
@ -387,18 +396,8 @@ def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
    if not isinstance(value, (int, float)):
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
        if not isinstance(expected, (int, float, str)):
            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value > expected:
        return False
    return True
    value, expected = _normalize_numeric_values(value, expected)
    return value <= expected


def _assert_null(*, value: object) -> bool:
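A quick sketch of the shared normalization (outside this changeset). Note the old per-comparator code called int(expected) whenever value was an int, so a decimal string such as "0.5" raised ValueError instead of being compared:

    assert _normalize_numeric_values(1, "0.5") == (1.0, 0.5)  # decimal string stays fractional
    assert _normalize_numeric_values(2, "2") == (2, 2)        # whole-number string keeps int comparison
    assert _assert_greater_than(value=1, expected="0.5") is True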
@ -0,0 +1,209 @@
import logging
from dataclasses import dataclass
from enum import StrEnum, auto

logger = logging.getLogger(__name__)


@dataclass
class QuotaCharge:
    """
    Result of a quota consumption operation.

    Attributes:
        success: Whether the quota charge succeeded
        charge_id: UUID for refund, or None if failed/disabled
    """

    success: bool
    charge_id: str | None
    _quota_type: "QuotaType"

    def refund(self) -> None:
        """
        Refund this quota charge.

        Safe to call even if the charge failed or was disabled.
        This method guarantees no exceptions will be raised.
        """
        if self.charge_id:
            self._quota_type.refund(self.charge_id)
            logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id)


class QuotaType(StrEnum):
    """
    Supported quota types for tenant feature usage.

    Add additional types here whenever new billable features become available.
    """

    # Trigger execution quota
    TRIGGER = auto()

    # Workflow execution quota
    WORKFLOW = auto()

    UNLIMITED = auto()

    @property
    def billing_key(self) -> str:
        """
        Get the billing key for the feature.
        """
        match self:
            case QuotaType.TRIGGER:
                return "trigger_event"
            case QuotaType.WORKFLOW:
                return "api_rate_limit"
            case _:
                raise ValueError(f"Invalid quota type: {self}")

    def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge:
        """
        Consume quota for the feature.

        Args:
            tenant_id: The tenant identifier
            amount: Amount to consume (default: 1)

        Returns:
            QuotaCharge with success status and charge_id for refund

        Raises:
            QuotaExceededError: When quota is insufficient
        """
        from configs import dify_config
        from services.billing_service import BillingService
        from services.errors.app import QuotaExceededError

        if not dify_config.BILLING_ENABLED:
            logger.debug("Billing disabled, allowing request for %s", tenant_id)
            return QuotaCharge(success=True, charge_id=None, _quota_type=self)

        logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id)

        if amount <= 0:
            raise ValueError("Amount to consume must be greater than 0")

        try:
            response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount)

            if response.get("result") != "success":
                logger.warning(
                    "Failed to consume quota for %s, feature %s, details: %s",
                    tenant_id,
                    self.value,
                    response.get("detail"),
                )
                raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount)

            charge_id = response.get("history_id")
            logger.debug(
                "Successfully consumed %d %s quota for tenant %s, charge_id: %s",
                amount,
                self.value,
                tenant_id,
                charge_id,
            )
            return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self)

        except QuotaExceededError:
            raise
        except Exception:
            # fail-safe: allow the request on billing errors
            logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value)
            return unlimited()

    def check(self, tenant_id: str, amount: int = 1) -> bool:
        """
        Check if the tenant has sufficient quota without consuming it.

        Args:
            tenant_id: The tenant identifier
            amount: Amount to check (default: 1)

        Returns:
            True if quota is sufficient, False otherwise
        """
        from configs import dify_config

        if not dify_config.BILLING_ENABLED:
            return True

        if amount <= 0:
            raise ValueError("Amount to check must be greater than 0")

        try:
            remaining = self.get_remaining(tenant_id)
            return remaining >= amount if remaining != -1 else True
        except Exception:
            logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value)
            # fail-safe: allow the request on billing errors
            return True

    def refund(self, charge_id: str) -> None:
        """
        Refund quota using the charge_id from consume().

        This method guarantees no exceptions will be raised.
        All errors are logged but silently handled.

        Args:
            charge_id: The UUID returned from consume()
        """
        try:
            from configs import dify_config
            from services.billing_service import BillingService

            if not dify_config.BILLING_ENABLED:
                return

            if not charge_id:
                logger.warning("Cannot refund: charge_id is empty")
                return

            logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id)

            response = BillingService.refund_tenant_feature_plan_usage(charge_id)
            if response.get("result") == "success":
                logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id)
            else:
                logger.warning("Refund failed for charge_id: %s", charge_id)

        except Exception:
            # Catch ALL exceptions - refund must never fail
            logger.exception("Failed to refund quota for charge_id: %s", charge_id)
            # Don't raise - refund is best-effort and must be silent

    def get_remaining(self, tenant_id: str) -> int:
        """
        Get the remaining quota for the tenant.

        Args:
            tenant_id: The tenant identifier

        Returns:
            Remaining quota amount
        """
        from services.billing_service import BillingService

        try:
            usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key)
            # Assuming the API returns a dict with 'remaining' or 'limit' and 'used'
            if isinstance(usage_info, dict):
                return usage_info.get("remaining", 0)
            # If it returns a simple number, treat it as remaining
            return int(usage_info) if usage_info else 0
        except Exception:
            logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value)
            return -1


def unlimited() -> QuotaCharge:
    """
    Return a quota charge for unlimited quota.

    This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type.
    """
    return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED)
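A usage sketch of the consume-then-refund pattern (outside this changeset); the tenant id and run_trigger() are illustrative:

    charge = QuotaType.TRIGGER.consume("tenant-123")  # raises QuotaExceededError when exhausted
    try:
        run_trigger()  # stands in for the billable work
    except Exception:
        charge.refund()  # best-effort, never raises; no-op when billing is disabled
        raise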
@ -10,7 +10,6 @@ from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.lock import Lock
from redis.sentinel import Sentinel

from configs import dify_config


@ -45,7 +45,6 @@ class ClickZettaVolumeConfig(BaseModel):
        This method will first try to use CLICKZETTA_VOLUME_* environment variables,
        then fall back to CLICKZETTA_* environment variables (for vector DB config).
        """
        import os

        # Helper function to get environment variable with fallback
        def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
@ -1,3 +1,4 @@
from .channel import BroadcastChannel
from .sharded_channel import ShardedRedisBroadcastChannel

__all__ = ["BroadcastChannel"]
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
@ -0,0 +1,205 @@
import logging
import queue
import threading
import types
from collections.abc import Generator, Iterator
from typing import Self

from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis.client import PubSub

_logger = logging.getLogger(__name__)


class RedisSubscriptionBase(Subscription):
    """Base class for Redis pub/sub subscriptions with common functionality.

    This class provides shared functionality for both regular and sharded
    Redis pub/sub subscriptions, reducing code duplication and improving
    maintainability.
    """

    def __init__(
        self,
        pubsub: PubSub,
        topic: str,
    ):
        # The _pubsub is None only if the subscription is closed.
        self._pubsub: PubSub | None = pubsub
        self._topic = topic
        self._closed = threading.Event()
        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
        self._dropped_count = 0
        self._listener_thread: threading.Thread | None = None
        self._start_lock = threading.Lock()
        self._started = False

    def _start_if_needed(self) -> None:
        """Start the subscription if not already started."""
        with self._start_lock:
            if self._started:
                return
            if self._closed.is_set():
                raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
            if self._pubsub is None:
                raise SubscriptionClosedError(
                    f"The Redis {self._get_subscription_type()} subscription has been cleaned up"
                )

            self._subscribe()
            _logger.debug("Subscribed to %s channel %s", self._get_subscription_type(), self._topic)

            self._listener_thread = threading.Thread(
                target=self._listen,
                name=f"redis-{self._get_subscription_type().replace(' ', '-')}-broadcast-{self._topic}",
                daemon=True,
            )
            self._listener_thread.start()
            self._started = True

    def _listen(self) -> None:
        """Main listener loop for processing messages."""
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."
        while not self._closed.is_set():
            raw_message = self._get_message()

            if raw_message is None:
                continue

            if raw_message.get("type") != self._get_message_type():
                continue

            channel_field = raw_message.get("channel")
            if isinstance(channel_field, bytes):
                channel_name = channel_field.decode("utf-8")
            elif isinstance(channel_field, str):
                channel_name = channel_field
            else:
                channel_name = str(channel_field)

            if channel_name != self._topic:
                _logger.warning(
                    "Ignoring %s message from unexpected channel %s", self._get_subscription_type(), channel_name
                )
                continue

            payload_bytes: bytes | None = raw_message.get("data")
            if not isinstance(payload_bytes, bytes):
                _logger.error(
                    "Received invalid data from %s channel %s, type=%s",
                    self._get_subscription_type(),
                    self._topic,
                    type(payload_bytes),
                )
                continue

            self._enqueue_message(payload_bytes)

        _logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
        self._unsubscribe()
        pubsub.close()
        _logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
        self._pubsub = None

    def _enqueue_message(self, payload: bytes) -> None:
        """Enqueue a message to the internal queue with dropping behavior."""
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                    self._dropped_count += 1
                    _logger.debug(
                        "Dropped message from Redis %s subscription, topic=%s, total_dropped=%d",
                        self._get_subscription_type(),
                        self._topic,
                        self._dropped_count,
                    )
                except queue.Empty:
                    continue
        return

    def _message_iterator(self) -> Generator[bytes, None, None]:
        """Iterator for consuming messages from the subscription."""
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        """Return an iterator over messages from the subscription."""
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        """Receive the next message from the subscription."""
        if self._closed.is_set():
            raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        """Context manager entry point."""
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        """Context manager exit point."""
        self.close()
        return None

    def close(self) -> None:
        """Close the subscription and clean up resources."""
        if self._closed.is_set():
            return

        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
        # message retrieval method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        listener = self._listener_thread
        if listener is not None:
            listener.join(timeout=1.0)
        self._listener_thread = None

    # Abstract methods to be implemented by subclasses
    def _get_subscription_type(self) -> str:
        """Return the subscription type (e.g., 'regular' or 'sharded')."""
        raise NotImplementedError

    def _subscribe(self) -> None:
        """Subscribe to the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _unsubscribe(self) -> None:
        """Unsubscribe from the Redis topic using the appropriate command."""
        raise NotImplementedError

    def _get_message(self) -> dict | None:
        """Get a message from Redis using the appropriate method."""
        raise NotImplementedError

    def _get_message_type(self) -> str:
        """Return the expected message type (e.g., 'message' or 'smessage')."""
        raise NotImplementedError
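A consumption sketch for the base class's contract (outside this changeset), assuming BroadcastChannel exposes a topic() factory like its sharded counterpart below; handle() is an illustrative stand-in:

    with channel.topic("workflow-events").subscribe() as sub:
        payload = sub.receive(timeout=1.0)  # None on timeout
        for message in sub:                 # or iterate; the bounded queue drops the oldest on overflow
            handle(message)
            break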
@ -1,24 +1,15 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import types
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
Redis Pub/Sub based broadcast channel implementation.
|
||||
Redis Pub/Sub based broadcast channel implementation (regular, non-sharded).
|
||||
|
||||
Provides "at most once" delivery semantics for messages published to channels.
|
||||
Uses Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
Provides "at most once" delivery semantics for messages published to channels
|
||||
using Redis PUBLISH/SUBSCRIBE commands for real-time message delivery.
|
||||
|
||||
The `redis_client` used to construct BroadcastChannel should have `decode_responses` set to `False`.
|
||||
"""
|
||||
|
|
@ -54,147 +45,23 @@ class Topic:
|
|||
)
|
||||
|
||||
|
||||
class _RedisSubscription(Subscription):
|
||||
def __init__(
|
||||
self,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=1024)
|
||||
self._dropped_count = 0
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
class _RedisSubscription(RedisSubscriptionBase):
|
||||
"""Regular Redis pub/sub subscription implementation."""
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
with self._start_lock:
|
||||
if self._started:
|
||||
return
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError("The Redis subscription is closed")
|
||||
if self._pubsub is None:
|
||||
raise SubscriptionClosedError("The Redis subscription has been cleaned up")
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
self._pubsub.subscribe(self._topic)
|
||||
_logger.debug("Subscribed to channel %s", self._topic)
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.subscribe(self._topic)
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen,
|
||||
name=f"redis-broadcast-{self._topic}",
|
||||
daemon=True,
|
||||
)
|
||||
self._listener_thread.start()
|
||||
self._started = True
|
||||
def _unsubscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
self._pubsub.unsubscribe(self._topic)
|
||||
|
||||
def _listen(self) -> None:
|
||||
        pubsub = self._pubsub
        assert pubsub is not None, "PubSub should not be None while starting listening."
        while not self._closed.is_set():
            raw_message = self._get_message()
            if raw_message is None:
                continue

            if raw_message.get("type") != self._get_message_type():
                continue

            channel_field = raw_message.get("channel")
            if isinstance(channel_field, bytes):
                channel_name = channel_field.decode("utf-8")
            elif isinstance(channel_field, str):
                channel_name = channel_field
            else:
                channel_name = str(channel_field)

            if channel_name != self._topic:
                _logger.warning("Ignoring message from unexpected channel %s", channel_name)
                continue

            payload_bytes: bytes | None = raw_message.get("data")
            if not isinstance(payload_bytes, bytes):
                _logger.error("Received invalid data from channel %s, type=%s", self._topic, type(payload_bytes))
                continue

            self._enqueue_message(payload_bytes)

        _logger.debug("Listener thread stopped for channel %s", self._topic)
        self._unsubscribe()
        pubsub.close()
        _logger.debug("PubSub closed for topic %s", self._topic)
        self._pubsub = None

    def _get_message(self) -> dict | None:
        assert self._pubsub is not None
        return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)

    def _enqueue_message(self, payload: bytes) -> None:
        # Drop-oldest backpressure: when the queue is full, evict the oldest
        # message and retry until the payload is enqueued or the subscription closes.
        while not self._closed.is_set():
            try:
                self._queue.put_nowait(payload)
                return
            except queue.Full:
                try:
                    self._queue.get_nowait()
                    self._dropped_count += 1
                    _logger.debug(
                        "Dropped message from Redis subscription, topic=%s, total_dropped=%d",
                        self._topic,
                        self._dropped_count,
                    )
                except queue.Empty:
                    continue
        return

    def _message_iterator(self) -> Generator[bytes, None, None]:
        while not self._closed.is_set():
            try:
                item = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            yield item

    def __iter__(self) -> Iterator[bytes]:
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()
        return iter(self._message_iterator())

    def receive(self, timeout: float | None = None) -> bytes | None:
        if self._closed.is_set():
            raise SubscriptionClosedError("The Redis subscription is closed")
        self._start_if_needed()

        try:
            item = self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

        return item

    def __enter__(self) -> Self:
        self._start_if_needed()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        self.close()
        return None

    def close(self) -> None:
        if self._closed.is_set():
            return

        self._closed.set()
        # NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
        # `PubSub.get_message` method should NOT be called concurrently.
        #
        # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
        listener = self._listener_thread
        if listener is not None:
            listener.join(timeout=1.0)
            self._listener_thread = None

    def _get_message_type(self) -> str:
        return "message"

@@ -0,0 +1,65 @@
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis

from ._subscription import RedisSubscriptionBase


class ShardedRedisBroadcastChannel:
    """
    Redis 7.0+ Sharded Pub/Sub based broadcast channel implementation.

    Provides "at most once" delivery semantics using SPUBLISH/SSUBSCRIBE commands,
    distributing channels across Redis cluster nodes for better scalability.
    """

    def __init__(
        self,
        redis_client: Redis,
    ):
        self._client = redis_client

    def topic(self, topic: str) -> "ShardedTopic":
        return ShardedTopic(self._client, topic)


class ShardedTopic:
    def __init__(self, redis_client: Redis, topic: str):
        self._client = redis_client
        self._topic = topic

    def as_producer(self) -> Producer:
        return self

    def publish(self, payload: bytes) -> None:
        self._client.spublish(self._topic, payload)  # type: ignore[attr-defined]

    def as_subscriber(self) -> Subscriber:
        return self

    def subscribe(self) -> Subscription:
        return _RedisShardedSubscription(
            pubsub=self._client.pubsub(),
            topic=self._topic,
        )


class _RedisShardedSubscription(RedisSubscriptionBase):
    """Redis 7.0+ sharded pub/sub subscription implementation."""

    def _get_subscription_type(self) -> str:
        return "sharded"

    def _subscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.ssubscribe(self._topic)  # type: ignore[attr-defined]

    def _unsubscribe(self) -> None:
        assert self._pubsub is not None
        self._pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]

    def _get_message(self) -> dict | None:
        assert self._pubsub is not None
        return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1)  # type: ignore[attr-defined]

    def _get_message_type(self) -> str:
        return "smessage"

@@ -38,6 +38,12 @@ class EmailType(StrEnum):
    EMAIL_REGISTER = auto()
    EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
    RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
    TRIGGER_EVENTS_LIMIT_SANDBOX = auto()
    TRIGGER_EVENTS_LIMIT_PROFESSIONAL = auto()
    TRIGGER_EVENTS_USAGE_WARNING_SANDBOX = auto()
    TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL = auto()
    API_RATE_LIMIT_LIMIT_SANDBOX = auto()
    API_RATE_LIMIT_WARNING_SANDBOX = auto()


class EmailLanguage(StrEnum):

@@ -445,6 +451,78 @@ def create_default_email_config() -> EmailI18nConfig:
branded_template_path="clean_document_job_mail_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’ve reached your Sandbox Trigger Events limit",
|
||||
template_path="trigger_events_limit_template_en-US.html",
|
||||
branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的 Sandbox 触发事件额度已用尽",
|
||||
template_path="trigger_events_limit_template_zh-CN.html",
|
||||
branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.TRIGGER_EVENTS_LIMIT_PROFESSIONAL: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’ve reached your monthly Trigger Events limit",
|
||||
template_path="trigger_events_limit_template_en-US.html",
|
||||
branded_template_path="without-brand/trigger_events_limit_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的月度触发事件额度已用尽",
|
||||
template_path="trigger_events_limit_template_zh-CN.html",
|
||||
branded_template_path="without-brand/trigger_events_limit_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.TRIGGER_EVENTS_USAGE_WARNING_SANDBOX: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’re nearing your Sandbox Trigger Events limit",
|
||||
template_path="trigger_events_usage_warning_template_en-US.html",
|
||||
branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的 Sandbox 触发事件额度接近上限",
|
||||
template_path="trigger_events_usage_warning_template_zh-CN.html",
|
||||
branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.TRIGGER_EVENTS_USAGE_WARNING_PROFESSIONAL: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’re nearing your Monthly Trigger Events limit",
|
||||
template_path="trigger_events_usage_warning_template_en-US.html",
|
||||
branded_template_path="without-brand/trigger_events_usage_warning_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的月度触发事件额度接近上限",
|
||||
template_path="trigger_events_usage_warning_template_zh-CN.html",
|
||||
branded_template_path="without-brand/trigger_events_usage_warning_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.API_RATE_LIMIT_LIMIT_SANDBOX: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’ve reached your API Rate Limit",
|
||||
template_path="api_rate_limit_limit_template_en-US.html",
|
||||
branded_template_path="without-brand/api_rate_limit_limit_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的 API 速率额度已用尽",
|
||||
template_path="api_rate_limit_limit_template_zh-CN.html",
|
||||
branded_template_path="without-brand/api_rate_limit_limit_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.API_RATE_LIMIT_WARNING_SANDBOX: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="You’re nearing your API Rate Limit",
|
||||
template_path="api_rate_limit_warning_template_en-US.html",
|
||||
branded_template_path="without-brand/api_rate_limit_warning_template_en-US.html",
|
||||
),
|
||||
EmailLanguage.ZH_HANS: EmailTemplate(
|
||||
subject="您的 API 速率额度接近上限",
|
||||
template_path="api_rate_limit_warning_template_zh-CN.html",
|
||||
branded_template_path="without-brand/api_rate_limit_warning_template_zh-CN.html",
|
||||
),
|
||||
},
|
||||
EmailType.EMAIL_REGISTER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Register Your {application_title} Account",

@@ -177,6 +177,15 @@ def timezone(timezone_string):
    raise ValueError(error)


def convert_datetime_to_date(field, target_timezone: str = ":tz"):
    if dify_config.DB_TYPE == "postgresql":
        return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
    elif dify_config.DB_TYPE == "mysql":
        return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
    else:
        raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")


def generate_string(n):
    letters_digits = string.ascii_letters + string.digits
    result = ""

@@ -8,6 +8,12 @@ Create Date: 2024-01-07 04:07:34.482983
import sqlalchemy as sa
from alembic import op

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'

@@ -17,17 +23,31 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
            batch_op.drop_column('description_str')
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
            batch_op.drop_column('description_str')

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
            batch_op.drop_column('description')
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
            batch_op.drop_column('description')

    # ### end Alembic commands ###
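
The models.types.LongText used in the MySQL branches throughout these migrations is the repo's dialect-aware text type; a plausible shape for such a type (a sketch under that assumption, not the repo's actual definition) is a TypeDecorator that emits LONGTEXT on MySQL, since MySQL's plain TEXT caps at 64 KB while PostgreSQL's TEXT is unbounded:

    import sqlalchemy as sa
    from sqlalchemy.dialects.mysql import LONGTEXT


    class LongText(sa.types.TypeDecorator):
        """Renders as TEXT on PostgreSQL and LONGTEXT on MySQL."""

        impl = sa.Text
        cache_ok = True

        def load_dialect_impl(self, dialect):
            # Pick the concrete column type per dialect at compile time.
            if dialect.name == "mysql":
                return dialect.type_descriptor(LONGTEXT())
            return dialect.type_descriptor(sa.Text())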

@@ -10,6 +10,10 @@ from alembic import op
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '04c602f5dc9b'
down_revision = '4ff534e1eb11'

@@ -19,15 +23,28 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tracing_app_configs',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('tracing_provider', sa.String(length=255), nullable=True),
            sa.Column('tracing_config', sa.JSON(), nullable=True),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey')
        )
    else:
        op.create_table('tracing_app_configs',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('tracing_provider', sa.String(length=255), nullable=True),
            sa.Column('tracing_config', sa.JSON(), nullable=True),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey')
        )

    # ### end Alembic commands ###

@@ -9,6 +9,12 @@ import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '053da0c1d756'
down_revision = '4829e54d2fee'

@@ -18,16 +24,31 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tool_conversation_variables',
            sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('user_id', postgresql.UUID(), nullable=False),
            sa.Column('tenant_id', postgresql.UUID(), nullable=False),
            sa.Column('conversation_id', postgresql.UUID(), nullable=False),
            sa.Column('variables_str', sa.Text(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey')
        )
    else:
        op.create_table('tool_conversation_variables',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('user_id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
            sa.Column('variables_str', models.types.LongText(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey')
        )

    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('privacy_policy', sa.String(length=255), nullable=True))
        batch_op.alter_column('icon',

@@ -9,6 +9,12 @@ import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'

@@ -26,7 +32,13 @@ def upgrade():
def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
    else:
        with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
            batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))

    # ### end Alembic commands ###

@@ -8,7 +8,11 @@ Create Date: 2024-07-05 14:30:59.472593
import sqlalchemy as sa
from alembic import op

import models as models
import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '161cadc1af8d'

@@ -19,9 +23,16 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
            # Add the tenant_id column
            op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
    else:
        with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
            # Add the tenant_id column
            op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))

    # ### end Alembic commands ###

@@ -9,6 +9,12 @@ import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models.types


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '16fa53d9faec'
down_revision = '8d2d099ceb74'

@@ -18,44 +24,87 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('provider_models',
            sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', postgresql.UUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('model_name', sa.String(length=40), nullable=False),
            sa.Column('model_type', sa.String(length=40), nullable=False),
            sa.Column('encrypted_config', sa.Text(), nullable=True),
            sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='provider_model_pkey'),
            sa.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
        )
    else:
        op.create_table('provider_models',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('model_name', sa.String(length=40), nullable=False),
            sa.Column('model_type', sa.String(length=40), nullable=False),
            sa.Column('encrypted_config', models.types.LongText(), nullable=True),
            sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='provider_model_pkey'),
            sa.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name')
        )

    with op.batch_alter_table('provider_models', schema=None) as batch_op:
        batch_op.create_index('provider_model_tenant_id_provider_idx', ['tenant_id', 'provider_name'], unique=False)

    if _is_pg(conn):
        op.create_table('tenant_default_models',
            sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', postgresql.UUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('model_name', sa.String(length=40), nullable=False),
            sa.Column('model_type', sa.String(length=40), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tenant_default_model_pkey')
        )
    else:
        op.create_table('tenant_default_models',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('model_name', sa.String(length=40), nullable=False),
            sa.Column('model_type', sa.String(length=40), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tenant_default_model_pkey')
        )

    with op.batch_alter_table('tenant_default_models', schema=None) as batch_op:
        batch_op.create_index('tenant_default_model_tenant_id_provider_type_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False)

    if _is_pg(conn):
        op.create_table('tenant_preferred_model_providers',
            sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', postgresql.UUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('preferred_provider_type', sa.String(length=40), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey')
        )
    else:
        op.create_table('tenant_preferred_model_providers',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('provider_name', sa.String(length=40), nullable=False),
            sa.Column('preferred_provider_type', sa.String(length=40), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey')
        )

    with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op:
        batch_op.create_index('tenant_preferred_model_provider_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False)

@@ -8,6 +8,10 @@ Create Date: 2024-04-01 09:48:54.232201
import sqlalchemy as sa
from alembic import op


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '17b5ab037c40'
down_revision = 'a8f9b3c45e4a'

@@ -17,9 +21,14 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op:
            batch_op.add_column(sa.Column('data_source_type', sa.String(length=255), server_default=sa.text("'database'::character varying"), nullable=False))
    else:
        with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op:
            batch_op.add_column(sa.Column('data_source_type', sa.String(length=255), server_default=sa.text("'database'"), nullable=False))

    # ### end Alembic commands ###

@@ -10,6 +10,10 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '63a83fcf12ba'
down_revision = '1787fbae959a'

@@ -19,21 +23,39 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('workflow__conversation_variables',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('data', sa.Text(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey'))
        )
    else:
        op.create_table('workflow__conversation_variables',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('conversation_id', models.types.StringUUID(), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('data', models.types.LongText(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', 'conversation_id', name=op.f('workflow__conversation_variables_pkey'))
        )

    with op.batch_alter_table('workflow__conversation_variables', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('workflow__conversation_variables_app_id_idx'), ['app_id'], unique=False)
        batch_op.create_index(batch_op.f('workflow__conversation_variables_created_at_idx'), ['created_at'], unique=False)

    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.add_column(sa.Column('conversation_variables', sa.Text(), server_default='{}', nullable=False))
    else:
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.add_column(sa.Column('conversation_variables', models.types.LongText(), default='{}', nullable=False))

    # ### end Alembic commands ###

@@ -10,6 +10,10 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '0251a1c768cc'
down_revision = 'bbadea11becb'

@@ -19,18 +23,35 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tidb_auth_bindings',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
            sa.Column('cluster_id', sa.String(length=255), nullable=False),
            sa.Column('cluster_name', sa.String(length=255), nullable=False),
            sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False),
            sa.Column('account', sa.String(length=255), nullable=False),
            sa.Column('password', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey')
        )
    else:
        op.create_table('tidb_auth_bindings',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
            sa.Column('cluster_id', sa.String(length=255), nullable=False),
            sa.Column('cluster_name', sa.String(length=255), nullable=False),
            sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'"), nullable=False),
            sa.Column('account', sa.String(length=255), nullable=False),
            sa.Column('password', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey')
        )

    with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
        batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False)
        batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False)

@@ -10,6 +10,10 @@ from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = 'd57ba9ebb251'
down_revision = '675b5321501b'

@@ -22,8 +26,14 @@ def upgrade():
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))

    # Set parent_message_id for existing messages to distinguish them from new messages with actual parent IDs or NULLs
    conn = op.get_bind()
    if _is_pg(conn):
        # PostgreSQL: Use uuid_nil() function
        op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')
    else:
        # MySQL: Use a specific UUID value to represent nil
        op.execute("UPDATE messages SET parent_message_id = '00000000-0000-0000-0000-000000000000' WHERE parent_message_id IS NULL")

    # ### end Alembic commands ###
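
Both branches write the same value: PostgreSQL's uuid_nil() (from the uuid-ossp extension) returns the all-zeros nil UUID, which is exactly the literal used on the MySQL side. A quick check (sketch; the connection URL is a placeholder and assumes uuid-ossp is installed):

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://...")
    with engine.connect() as conn:
        nil = conn.execute(text("SELECT uuid_nil()")).scalar()
        assert str(nil) == "00000000-0000-0000-0000-000000000000"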

@@ -6,7 +6,11 @@ Create Date: 2024-09-24 09:22:43.570120
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import models.types
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql

@@ -19,30 +23,58 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('document_id',
                existing_type=sa.UUID(),
                nullable=True)
            batch_op.alter_column('data_source_type',
                existing_type=sa.TEXT(),
                nullable=True)
            batch_op.alter_column('segment_id',
                existing_type=sa.UUID(),
                nullable=True)
    else:
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('document_id',
                existing_type=models.types.StringUUID(),
                nullable=True)
            batch_op.alter_column('data_source_type',
                existing_type=models.types.LongText(),
                nullable=True)
            batch_op.alter_column('segment_id',
                existing_type=models.types.StringUUID(),
                nullable=True)
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('segment_id',
                existing_type=sa.UUID(),
                nullable=False)
            batch_op.alter_column('data_source_type',
                existing_type=sa.TEXT(),
                nullable=False)
            batch_op.alter_column('document_id',
                existing_type=sa.UUID(),
                nullable=False)
    else:
        with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
            batch_op.alter_column('segment_id',
                existing_type=models.types.StringUUID(),
                nullable=False)
            batch_op.alter_column('data_source_type',
                existing_type=models.types.LongText(),
                nullable=False)
            batch_op.alter_column('document_id',
                existing_type=models.types.StringUUID(),
                nullable=False)

    # ### end Alembic commands ###

@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '33f5fac87f29'
down_revision = '6af6a521a53e'

@@ -19,34 +23,66 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('external_knowledge_apis',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('description', sa.String(length=255), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('settings', sa.Text(), nullable=True),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='external_knowledge_apis_pkey')
        )
    else:
        op.create_table('external_knowledge_apis',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('description', sa.String(length=255), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('settings', models.types.LongText(), nullable=True),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='external_knowledge_apis_pkey')
        )

    with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op:
        batch_op.create_index('external_knowledge_apis_name_idx', ['name'], unique=False)
        batch_op.create_index('external_knowledge_apis_tenant_idx', ['tenant_id'], unique=False)

    if _is_pg(conn):
        op.create_table('external_knowledge_bindings',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('external_knowledge_api_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('external_knowledge_id', sa.Text(), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey')
        )
    else:
        op.create_table('external_knowledge_bindings',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('external_knowledge_api_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('external_knowledge_id', sa.String(length=512), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey')
        )

    with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
        batch_op.create_index('external_knowledge_bindings_dataset_idx', ['dataset_id'], unique=False)
        batch_op.create_index('external_knowledge_bindings_external_knowledge_api_idx', ['external_knowledge_api_id'], unique=False)

@@ -16,6 +16,10 @@ branch_labels = None
depends_on = None


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


def upgrade():
    def _has_name_or_size_column() -> bool:
        # We cannot access the database in offline mode, so assume

@@ -46,14 +50,26 @@ def upgrade():
    if _has_name_or_size_column():
        return

    if _is_pg(conn):
        # PostgreSQL: sa.String() without a length is valid
        with op.batch_alter_table("tool_files", schema=None) as batch_op:
            batch_op.add_column(sa.Column("name", sa.String(), nullable=True))
            batch_op.add_column(sa.Column("size", sa.Integer(), nullable=True))
        op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL")
        op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL")
        with op.batch_alter_table("tool_files", schema=None) as batch_op:
            batch_op.alter_column("name", existing_type=sa.String(), nullable=False)
            batch_op.alter_column("size", existing_type=sa.Integer(), nullable=False)
    else:
        # MySQL: VARCHAR requires an explicit length
        with op.batch_alter_table("tool_files", schema=None) as batch_op:
            batch_op.add_column(sa.Column("name", sa.String(length=255), nullable=True))
            batch_op.add_column(sa.Column("size", sa.Integer(), nullable=True))
        op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL")
        op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL")
        with op.batch_alter_table("tool_files", schema=None) as batch_op:
            batch_op.alter_column("name", existing_type=sa.String(length=255), nullable=False)
            batch_op.alter_column("size", existing_type=sa.Integer(), nullable=False)
    # ### end Alembic commands ###
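
The explicit length in the MySQL branch matters because MySQL's VARCHAR requires one, and SQLAlchemy refuses to render an unlengthed sa.String() for that dialect. A quick check (sketch):

    import sqlalchemy as sa
    from sqlalchemy.dialects import mysql, postgresql

    print(sa.String().compile(dialect=postgresql.dialect()))        # VARCHAR
    print(sa.String(length=255).compile(dialect=mysql.dialect()))   # VARCHAR(255)
    # sa.String().compile(dialect=mysql.dialect()) raises CompileError:
    # "VARCHAR requires a length on dialect mysql"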

@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '43fa78bc3b7d'
down_revision = '0251a1c768cc'

@@ -19,13 +23,25 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('whitelists',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
            sa.Column('category', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='whitelists_pkey')
        )
    else:
        op.create_table('whitelists',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
            sa.Column('category', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='whitelists_pkey')
        )

    with op.batch_alter_table('whitelists', schema=None) as batch_op:
        batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False)

@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '08ec4f75af5e'
down_revision = 'ddcc8bbef391'

@@ -19,14 +23,26 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('account_plugin_permissions',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('install_permission', sa.String(length=16), server_default='everyone', nullable=False),
            sa.Column('debug_permission', sa.String(length=16), server_default='noone', nullable=False),
            sa.PrimaryKeyConstraint('id', name='account_plugin_permission_pkey'),
            sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin')
        )
    else:
        op.create_table('account_plugin_permissions',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('install_permission', sa.String(length=16), server_default='everyone', nullable=False),
            sa.Column('debug_permission', sa.String(length=16), server_default='noone', nullable=False),
            sa.PrimaryKeyConstraint('id', name='account_plugin_permission_pkey'),
            sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin')
        )

    # ### end Alembic commands ###

@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = 'f4d7ce70a7ca'
down_revision = '93ad8c19c40b'

@@ -19,23 +23,43 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('upload_files', schema=None) as batch_op:
            batch_op.alter_column('source_url',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                existing_nullable=False,
                existing_server_default=sa.text("''::character varying"))
    else:
        with op.batch_alter_table('upload_files', schema=None) as batch_op:
            batch_op.alter_column('source_url',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                existing_nullable=False,
                existing_default=sa.text("''"))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('upload_files', schema=None) as batch_op:
            batch_op.alter_column('source_url',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                existing_nullable=False,
                existing_server_default=sa.text("''::character varying"))
    else:
        with op.batch_alter_table('upload_files', schema=None) as batch_op:
            batch_op.alter_column('source_url',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                existing_nullable=False,
                existing_default=sa.text("''"))

    # ### end Alembic commands ###

@@ -7,6 +7,9 @@ Create Date: 2024-11-01 06:22:27.981398
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql

@@ -19,49 +22,91 @@ depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    op.execute("UPDATE recommended_apps SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
    op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
    op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")

    if _is_pg(conn):
        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)

        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=sa.TEXT(),
                nullable=False)
    else:
        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)

        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.VARCHAR(length=255),
                type_=models.types.LongText(),
                nullable=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=sa.TEXT(),
                type_=sa.VARCHAR(length=255),
                nullable=True)
    else:
        with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('sites', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

        with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
            batch_op.alter_column('custom_disclaimer',
                existing_type=models.types.LongText(),
                type_=sa.VARCHAR(length=255),
                nullable=True)

    # ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
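The three UPDATE statements at the top of upgrade() are what make the later nullable=False safe: SET NOT NULL (and MySQL's equivalent MODIFY) fails while NULL rows remain. The backfill-then-tighten pattern in isolation, reusing the sites table from this migration:

import sqlalchemy as sa
from alembic import op


def upgrade():
    # 1) Backfill: replace NULLs so the constraint change cannot fail mid-migration.
    op.execute(sa.text("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL"))
    # 2) Tighten: NOT NULL now holds for every existing row on either dialect.
    with op.batch_alter_table('sites', schema=None) as batch_op:
        batch_op.alter_column('custom_disclaimer', existing_type=sa.TEXT(), nullable=False)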
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '09a8d1878d9b'
down_revision = 'd07474999927'
@@ -19,55 +23,103 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('conversations', schema=None) as batch_op:
        batch_op.alter_column('inputs',
                existing_type=postgresql.JSON(astext_type=sa.Text()),
                nullable=False)
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=postgresql.JSON(astext_type=sa.Text()),
                    nullable=False)

    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.alter_column('inputs',
                existing_type=postgresql.JSON(astext_type=sa.Text()),
                nullable=False)
        with op.batch_alter_table('messages', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=postgresql.JSON(astext_type=sa.Text()),
                    nullable=False)
    else:
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=sa.JSON(),
                    nullable=False)

        with op.batch_alter_table('messages', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=sa.JSON(),
                    nullable=False)

    op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
    op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
    op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")

    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.alter_column('graph',
                existing_type=sa.TEXT(),
                nullable=False)
        batch_op.alter_column('features',
                existing_type=sa.TEXT(),
                nullable=False)
        batch_op.alter_column('updated_at',
                existing_type=postgresql.TIMESTAMP(),
                nullable=False)

    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('graph',
                    existing_type=sa.TEXT(),
                    nullable=False)
            batch_op.alter_column('features',
                    existing_type=sa.TEXT(),
                    nullable=False)
            batch_op.alter_column('updated_at',
                    existing_type=postgresql.TIMESTAMP(),
                    nullable=False)
    else:
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('graph',
                    existing_type=models.types.LongText(),
                    nullable=False)
            batch_op.alter_column('features',
                    existing_type=models.types.LongText(),
                    nullable=False)
            batch_op.alter_column('updated_at',
                    existing_type=sa.TIMESTAMP(),
                    nullable=False)
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.alter_column('updated_at',
                existing_type=postgresql.TIMESTAMP(),
                nullable=True)
        batch_op.alter_column('features',
                existing_type=sa.TEXT(),
                nullable=True)
        batch_op.alter_column('graph',
                existing_type=sa.TEXT(),
                nullable=True)
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('updated_at',
                    existing_type=postgresql.TIMESTAMP(),
                    nullable=True)
            batch_op.alter_column('features',
                    existing_type=sa.TEXT(),
                    nullable=True)
            batch_op.alter_column('graph',
                    existing_type=sa.TEXT(),
                    nullable=True)
    else:
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.alter_column('updated_at',
                    existing_type=sa.TIMESTAMP(),
                    nullable=True)
            batch_op.alter_column('features',
                    existing_type=models.types.LongText(),
                    nullable=True)
            batch_op.alter_column('graph',
                    existing_type=models.types.LongText(),
                    nullable=True)

    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.alter_column('inputs',
                existing_type=postgresql.JSON(astext_type=sa.Text()),
                nullable=True)
    if _is_pg(conn):
        with op.batch_alter_table('messages', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=postgresql.JSON(astext_type=sa.Text()),
                    nullable=True)

    with op.batch_alter_table('conversations', schema=None) as batch_op:
        batch_op.alter_column('inputs',
                existing_type=postgresql.JSON(astext_type=sa.Text()),
                nullable=True)
        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=postgresql.JSON(astext_type=sa.Text()),
                    nullable=True)
    else:
        with op.batch_alter_table('messages', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=sa.JSON(),
                    nullable=True)

        with op.batch_alter_table('conversations', schema=None) as batch_op:
            batch_op.alter_column('inputs',
                    existing_type=sa.JSON(),
                    nullable=True)

    # ### end Alembic commands ###
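The branch split above hinges on type objects, not data: existing_type only tells Alembic (and batch mode) what the column currently is, so the PostgreSQL branch names the reflected postgresql.JSON while the portable branch passes the generic sa.JSON, which each backend renders natively (JSON on MySQL 5.7+). A side-by-side sketch:

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# Dialect-bound type: only meaningful when the bind is PostgreSQL.
inputs_pg = sa.Column('inputs', postgresql.JSON(astext_type=sa.Text()), nullable=False)

# Generic type: SQLAlchemy picks the right DDL per backend at compile time.
inputs_any = sa.Column('inputs', sa.JSON(), nullable=False)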
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = 'e19037032219'
down_revision = 'd7999dfa4aae'
@@ -19,27 +23,53 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('child_chunks',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
        sa.Column('document_id', models.types.StringUUID(), nullable=False),
        sa.Column('segment_id', models.types.StringUUID(), nullable=False),
        sa.Column('position', sa.Integer(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('word_count', sa.Integer(), nullable=False),
        sa.Column('index_node_id', sa.String(length=255), nullable=True),
        sa.Column('index_node_hash', sa.String(length=255), nullable=True),
        sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_by', models.types.StringUUID(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('indexing_at', sa.DateTime(), nullable=True),
        sa.Column('completed_at', sa.DateTime(), nullable=True),
        sa.Column('error', sa.Text(), nullable=True),
        sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('child_chunks',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('segment_id', models.types.StringUUID(), nullable=False),
            sa.Column('position', sa.Integer(), nullable=False),
            sa.Column('content', sa.Text(), nullable=False),
            sa.Column('word_count', sa.Integer(), nullable=False),
            sa.Column('index_node_id', sa.String(length=255), nullable=True),
            sa.Column('index_node_hash', sa.String(length=255), nullable=True),
            sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('indexing_at', sa.DateTime(), nullable=True),
            sa.Column('completed_at', sa.DateTime(), nullable=True),
            sa.Column('error', sa.Text(), nullable=True),
            sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
        )
    else:
        op.create_table('child_chunks',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('segment_id', models.types.StringUUID(), nullable=False),
            sa.Column('position', sa.Integer(), nullable=False),
            sa.Column('content', models.types.LongText(), nullable=False),
            sa.Column('word_count', sa.Integer(), nullable=False),
            sa.Column('index_node_id', sa.String(length=255), nullable=True),
            sa.Column('index_node_hash', sa.String(length=255), nullable=True),
            sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'"), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('indexing_at', sa.DateTime(), nullable=True),
            sa.Column('completed_at', sa.DateTime(), nullable=True),
            sa.Column('error', models.types.LongText(), nullable=True),
            sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
        )

    with op.batch_alter_table('child_chunks', schema=None) as batch_op:
        batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)
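Note what the MySQL branch drops: server_default=sa.text('uuid_generate_v4()') relies on PostgreSQL's uuid-ossp extension, and MySQL has no drop-in expression default here, so row IDs must come from the application. A hypothetical model-side default (the default= lambda is an illustration, not taken from this diff):

import uuid

import sqlalchemy as sa
import models

# Assumption: with no DB-side default on MySQL, the ORM layer supplies the UUID.
id_column = sa.Column(
    'id',
    models.types.StringUUID(),
    primary_key=True,
    default=lambda: str(uuid.uuid4()),  # hypothetical client-generated v4 UUID
)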
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '11b07f66c737'
down_revision = 'cf8f4fc45278'
@@ -25,15 +29,30 @@ def upgrade():

def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tool_providers',
        sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
        sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
        sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
        sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
        sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
        sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tool_providers',
            sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
            sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
            sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
            sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
            sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
            sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
            sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
            sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
        )
    else:
        op.create_table('tool_providers',
            sa.Column('id', models.types.StringUUID(), autoincrement=False, nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), autoincrement=False, nullable=False),
            sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
            sa.Column('encrypted_credentials', models.types.LongText(), autoincrement=False, nullable=True),
            sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
            sa.Column('created_at', sa.TIMESTAMP(), server_default=sa.func.current_timestamp(), autoincrement=False, nullable=False),
            sa.Column('updated_at', sa.TIMESTAMP(), server_default=sa.func.current_timestamp(), autoincrement=False, nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
            sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
        )
    # ### end Alembic commands ###
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '923752d42eb6'
down_revision = 'e19037032219'
@@ -19,15 +23,29 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_auto_disable_logs',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
        sa.Column('document_id', models.types.StringUUID(), nullable=False),
        sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('dataset_auto_disable_logs',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
        )
    else:
        op.create_table('dataset_auto_disable_logs',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
        )

    with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
        batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
        batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
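The created_at defaults differ between branches for the same reason as the UUIDs: CURRENT_TIMESTAMP(0) — with an explicit fractional-seconds precision — is the form PostgreSQL reflects, while sa.func.current_timestamp() lets SQLAlchemy emit whatever spelling the active dialect accepts. Compactly:

import sqlalchemy as sa

# PostgreSQL branch: keep the reflected textual default verbatim.
created_at_pg = sa.Column('created_at', sa.DateTime(),
        server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False)

# Portable branch: the function construct compiles per dialect.
created_at_any = sa.Column('created_at', sa.DateTime(),
        server_default=sa.func.current_timestamp(), nullable=False)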
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = 'f051706725cc'
down_revision = 'ee79d9b1c156'
@@ -19,14 +23,27 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('rate_limit_logs',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('subscription_plan', sa.String(length=255), nullable=False),
        sa.Column('operation', sa.String(length=255), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('rate_limit_logs',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('subscription_plan', sa.String(length=255), nullable=False),
            sa.Column('operation', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
        )
    else:
        op.create_table('rate_limit_logs',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('subscription_plan', sa.String(length=255), nullable=False),
            sa.Column('operation', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
        )

    with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
        batch_op.create_index('rate_limit_log_operation_idx', ['operation'], unique=False)
        batch_op.create_index('rate_limit_log_tenant_idx', ['tenant_id'], unique=False)
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = 'd20049ed0af6'
down_revision = 'f051706725cc'
@@ -19,34 +23,66 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_metadata_bindings',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
        sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
        sa.Column('document_id', models.types.StringUUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('dataset_metadata_bindings',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
        )
    else:
        op.create_table('dataset_metadata_bindings',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
            sa.Column('document_id', models.types.StringUUID(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
        )

    with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
        batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
        batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)

    op.create_table('dataset_metadatas',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
        sa.Column('type', sa.String(length=255), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('created_by', models.types.StringUUID(), nullable=False),
        sa.Column('updated_by', models.types.StringUUID(), nullable=True),
        sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
    )
    if _is_pg(conn):
        # PostgreSQL: Keep original syntax
        op.create_table('dataset_metadatas',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('type', sa.String(length=255), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
        )
    else:
        # MySQL: Use compatible syntax
        op.create_table('dataset_metadatas',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
            sa.Column('type', sa.String(length=255), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('created_by', models.types.StringUUID(), nullable=False),
            sa.Column('updated_by', models.types.StringUUID(), nullable=True),
            sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
        )

    with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
        batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
        batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)
@@ -54,23 +90,31 @@ def upgrade():
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.alter_column('doc_metadata',
                existing_type=postgresql.JSON(astext_type=sa.Text()),
                type_=postgresql.JSONB(astext_type=sa.Text()),
                existing_nullable=True)
        batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
    if _is_pg(conn):
        with op.batch_alter_table('documents', schema=None) as batch_op:
            batch_op.alter_column('doc_metadata',
                    existing_type=postgresql.JSON(astext_type=sa.Text()),
                    type_=postgresql.JSONB(astext_type=sa.Text()),
                    existing_nullable=True)
            batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
    else:
        pass
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
        batch_op.alter_column('doc_metadata',
                existing_type=postgresql.JSONB(astext_type=sa.Text()),
                type_=postgresql.JSON(astext_type=sa.Text()),
                existing_nullable=True)
    conn = op.get_bind()

    if _is_pg(conn):
        with op.batch_alter_table('documents', schema=None) as batch_op:
            batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
            batch_op.alter_column('doc_metadata',
                    existing_type=postgresql.JSONB(astext_type=sa.Text()),
                    type_=postgresql.JSON(astext_type=sa.Text()),
                    existing_nullable=True)
    else:
        pass

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('built_in_field_enabled')
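The JSON-to-JSONB switch plus the GIN index is deliberately PostgreSQL-only (the else branch is an explicit pass): JSONB is the binary representation GIN can index, which is what makes containment lookups on doc_metadata cheap. A usage sketch of the kind of query this index serves, with an illustrative probe value:

import sqlalchemy as sa

# PostgreSQL-only: @> is JSONB containment, answered from the GIN index.
stmt = sa.text("SELECT id FROM documents WHERE doc_metadata @> :probe")
# conn.execute(stmt, {"probe": '{"language": "en"}'})  # illustrative filter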
@@ -17,10 +17,23 @@ branch_labels = None
depends_on = None


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


def upgrade():
    with op.batch_alter_table('workflows', schema=None) as batch_op:
        batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default=''))
        batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default=''))
    conn = op.get_bind()

    if _is_pg(conn):
        # PostgreSQL: Keep original syntax
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.add_column(sa.Column('marked_name', sa.String(), nullable=False, server_default=''))
            batch_op.add_column(sa.Column('marked_comment', sa.String(), nullable=False, server_default=''))
    else:
        # MySQL: Use compatible syntax
        with op.batch_alter_table('workflows', schema=None) as batch_op:
            batch_op.add_column(sa.Column('marked_name', sa.String(length=255), nullable=False, server_default=''))
            batch_op.add_column(sa.Column('marked_comment', sa.String(length=255), nullable=False, server_default=''))


def downgrade():
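The only delta between those two branches is the length: PostgreSQL accepts a bare character varying, but MySQL refuses to create a VARCHAR without an explicit size, hence the pinned length=255 on the portable side:

import sqlalchemy as sa

# Fine on PostgreSQL, rejected by MySQL DDL:
marked_name_pg = sa.Column('marked_name', sa.String(), nullable=False, server_default='')
# Works on both backends:
marked_name_any = sa.Column('marked_name', sa.String(length=255), nullable=False, server_default='')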
@@ -11,6 +11,10 @@ from alembic import op

import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = "2adcbe1f5dfb"
down_revision = "d28f2004b072"
@@ -20,24 +24,46 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "workflow_draft_variables",
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("app_id", models.types.StringUUID(), nullable=False),
        sa.Column("last_edited_at", sa.DateTime(), nullable=True),
        sa.Column("node_id", sa.String(length=255), nullable=False),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column("description", sa.String(length=255), nullable=False),
        sa.Column("selector", sa.String(length=255), nullable=False),
        sa.Column("value_type", sa.String(length=20), nullable=False),
        sa.Column("value", sa.Text(), nullable=False),
        sa.Column("visible", sa.Boolean(), nullable=False),
        sa.Column("editable", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
        sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table(
            "workflow_draft_variables",
            sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
            sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.Column("app_id", models.types.StringUUID(), nullable=False),
            sa.Column("last_edited_at", sa.DateTime(), nullable=True),
            sa.Column("node_id", sa.String(length=255), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", sa.String(length=255), nullable=False),
            sa.Column("selector", sa.String(length=255), nullable=False),
            sa.Column("value_type", sa.String(length=20), nullable=False),
            sa.Column("value", sa.Text(), nullable=False),
            sa.Column("visible", sa.Boolean(), nullable=False),
            sa.Column("editable", sa.Boolean(), nullable=False),
            sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
            sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
        )
    else:
        op.create_table(
            "workflow_draft_variables",
            sa.Column("id", models.types.StringUUID(), nullable=False),
            sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column("app_id", models.types.StringUUID(), nullable=False),
            sa.Column("last_edited_at", sa.DateTime(), nullable=True),
            sa.Column("node_id", sa.String(length=255), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", sa.String(length=255), nullable=False),
            sa.Column("selector", sa.String(length=255), nullable=False),
            sa.Column("value_type", sa.String(length=20), nullable=False),
            sa.Column("value", models.types.LongText(), nullable=False),
            sa.Column("visible", sa.Boolean(), nullable=False),
            sa.Column("editable", sa.Boolean(), nullable=False),
            sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
            sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
        )

    # ### end Alembic commands ###
@@ -7,6 +7,10 @@ Create Date: 2025-06-06 14:24:44.213018
"""
from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"
import sqlalchemy as sa

@@ -18,19 +22,30 @@ depends_on = None


def upgrade():
    # `CREATE INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
    # context manager to wrap the index creation statement.
    # Reference:
    #
    # - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
    # - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
    with op.get_context().autocommit_block():
        # ### commands auto generated by Alembic - please adjust! ###
    conn = op.get_bind()

    if _is_pg(conn):
        # `CREATE INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
        # context manager to wrap the index creation statement.
        # Reference:
        #
        # - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
        # - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
        with op.get_context().autocommit_block():
            op.create_index(
                op.f('workflow_node_executions_tenant_id_idx'),
                "workflow_node_executions",
                ['tenant_id', 'workflow_id', 'node_id', sa.literal_column('created_at DESC')],
                unique=False,
                postgresql_concurrently=True,
            )
    else:
        op.create_index(
            op.f('workflow_node_executions_tenant_id_idx'),
            "workflow_node_executions",
            ['tenant_id', 'workflow_id', 'node_id', sa.literal_column('created_at DESC')],
            unique=False,
            postgresql_concurrently=True,
        )

    with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
@@ -51,8 +66,13 @@ def downgrade():
    # Reference:
    #
    # https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
    with op.get_context().autocommit_block():
        op.drop_index(op.f('workflow_node_executions_tenant_id_idx'), postgresql_concurrently=True)
    conn = op.get_bind()

    if _is_pg(conn):
        with op.get_context().autocommit_block():
            op.drop_index(op.f('workflow_node_executions_tenant_id_idx'), postgresql_concurrently=True)
    else:
        op.drop_index(op.f('workflow_node_executions_tenant_id_idx'))

    with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
        batch_op.drop_column('node_execution_id')
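The asymmetry here is transactional: PostgreSQL refuses CREATE/DROP INDEX CONCURRENTLY inside a transaction, so the PG path wraps the statement in autocommit_block(), while other backends take a plain index build and silently ignore postgresql_* keyword arguments. The pattern reduced to its core (ix_example is an illustrative name):

from alembic import op


def upgrade():
    conn = op.get_bind()
    if conn.dialect.name == "postgresql":
        # CONCURRENTLY must run outside the migration's transaction.
        with op.get_context().autocommit_block():
            op.create_index('ix_example', 'workflow_node_executions',
                            ['tenant_id'], unique=False, postgresql_concurrently=True)
    else:
        # A plain, transactional index build is the portable fallback.
        op.create_index('ix_example', 'workflow_node_executions',
                        ['tenant_id'], unique=False)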
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '58eb7bdb93fe'
down_revision = '0ab65e1cc7fa'
@@ -19,40 +23,80 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('app_mcp_servers',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('app_id', models.types.StringUUID(), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=False),
        sa.Column('description', sa.String(length=255), nullable=False),
        sa.Column('server_code', sa.String(length=255), nullable=False),
        sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
        sa.Column('parameters', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
        sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
        sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
    )
    op.create_table('tool_mcp_providers',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('name', sa.String(length=40), nullable=False),
        sa.Column('server_identifier', sa.String(length=24), nullable=False),
        sa.Column('server_url', sa.Text(), nullable=False),
        sa.Column('server_url_hash', sa.String(length=64), nullable=False),
        sa.Column('icon', sa.String(length=255), nullable=True),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('user_id', models.types.StringUUID(), nullable=False),
        sa.Column('encrypted_credentials', sa.Text(), nullable=True),
        sa.Column('authed', sa.Boolean(), nullable=False),
        sa.Column('tools', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
        sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'),
        sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'),
        sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('app_mcp_servers',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('description', sa.String(length=255), nullable=False),
            sa.Column('server_code', sa.String(length=255), nullable=False),
            sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
            sa.Column('parameters', sa.Text(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
            sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
            sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
        )
    else:
        op.create_table('app_mcp_servers',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('app_id', models.types.StringUUID(), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('description', sa.String(length=255), nullable=False),
            sa.Column('server_code', sa.String(length=255), nullable=False),
            sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'"), nullable=False),
            sa.Column('parameters', models.types.LongText(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
            sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
            sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
        )
    if _is_pg(conn):
        op.create_table('tool_mcp_providers',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('name', sa.String(length=40), nullable=False),
            sa.Column('server_identifier', sa.String(length=24), nullable=False),
            sa.Column('server_url', sa.Text(), nullable=False),
            sa.Column('server_url_hash', sa.String(length=64), nullable=False),
            sa.Column('icon', sa.String(length=255), nullable=True),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('user_id', models.types.StringUUID(), nullable=False),
            sa.Column('encrypted_credentials', sa.Text(), nullable=True),
            sa.Column('authed', sa.Boolean(), nullable=False),
            sa.Column('tools', sa.Text(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
            sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'),
            sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'),
            sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url')
        )
    else:
        op.create_table('tool_mcp_providers',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('name', sa.String(length=40), nullable=False),
            sa.Column('server_identifier', sa.String(length=24), nullable=False),
            sa.Column('server_url', models.types.LongText(), nullable=False),
            sa.Column('server_url_hash', sa.String(length=64), nullable=False),
            sa.Column('icon', sa.String(length=255), nullable=True),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('user_id', models.types.StringUUID(), nullable=False),
            sa.Column('encrypted_credentials', models.types.LongText(), nullable=True),
            sa.Column('authed', sa.Boolean(), nullable=False),
            sa.Column('tools', models.types.LongText(), nullable=False),
            sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
            sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'),
            sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'),
            sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url')
        )

    # ### end Alembic commands ###
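models.types.LongText is not defined anywhere in this diff; presumably it exists so long string payloads (tool manifests, credentials, serialized parameters) survive on MySQL, where plain TEXT tops out at 64 KB while LONGTEXT allows 4 GB. A hypothetical reconstruction of such a type, offered as an assumption rather than the project's actual implementation:

import sqlalchemy as sa
from sqlalchemy.dialects.mysql import LONGTEXT


class LongText(sa.types.TypeDecorator):
    """Sketch: renders LONGTEXT on MySQL, plain TEXT elsewhere."""

    impl = sa.Text
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Pick the concrete column type per backend at DDL-compile time.
        if dialect.name == "mysql":
            return dialect.type_descriptor(LONGTEXT())
        return dialect.type_descriptor(sa.Text())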
@@ -27,6 +27,10 @@ import models as models
import sqlalchemy as sa


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '1c9ba48be8e4'
down_revision = '58eb7bdb93fe'
@@ -40,7 +44,11 @@ def upgrade():
    # The ability to specify source timestamp has been removed because its type signature is incompatible with
    # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
    # generated and controlled within the application layer.
    op.execute(sa.text(r"""
    conn = op.get_bind()

    if _is_pg(conn):
        # PostgreSQL: Create uuidv7 functions
        op.execute(sa.text(r"""
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION uuidv7() RETURNS uuid
AS
@@ -63,7 +71,7 @@ COMMENT ON FUNCTION uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
"""))

    op.execute(sa.text(r"""
        op.execute(sa.text(r"""
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
AS
$$
@@ -79,8 +87,15 @@ COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
"""
        ))
    else:
        pass


def downgrade():
    op.execute(sa.text("DROP FUNCTION uuidv7"))
    op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
    conn = op.get_bind()

    if _is_pg(conn):
        op.execute(sa.text("DROP FUNCTION uuidv7"))
        op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
    else:
        pass
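So on PostgreSQL this revision installs uuidv7() and uuidv7_boundary(timestamptz) as SQL functions, and on every other backend it is a no-op — consistent with the comment above that IDs can be generated in the application layer instead. A quick probe of the installed function, PostgreSQL only:

import sqlalchemy as sa
from alembic import op

conn = op.get_bind()
if conn.dialect.name == "postgresql":
    # Time-ordered UUID: the first 48 bits encode milliseconds since the epoch.
    new_id = conn.execute(sa.text("SELECT uuidv7()")).scalar_one()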
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = '71f5020c6470'
down_revision = '1c9ba48be8e4'
@@ -19,31 +23,63 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tool_oauth_system_clients',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('plugin_id', sa.String(length=512), nullable=False),
        sa.Column('provider', sa.String(length=255), nullable=False),
        sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
        sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
    )
    op.create_table('tool_oauth_tenant_clients',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('plugin_id', sa.String(length=512), nullable=False),
        sa.Column('provider', sa.String(length=255), nullable=False),
        sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
        sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
        sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tool_oauth_system_clients',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('plugin_id', sa.String(length=512), nullable=False),
            sa.Column('provider', sa.String(length=255), nullable=False),
            sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
            sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
        )
    else:
        op.create_table('tool_oauth_system_clients',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('plugin_id', sa.String(length=512), nullable=False),
            sa.Column('provider', sa.String(length=255), nullable=False),
            sa.Column('encrypted_oauth_params', models.types.LongText(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
            sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
        )
    if _is_pg(conn):
        op.create_table('tool_oauth_tenant_clients',
            sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('plugin_id', sa.String(length=512), nullable=False),
            sa.Column('provider', sa.String(length=255), nullable=False),
            sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
            sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
            sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
        )
    else:
        op.create_table('tool_oauth_tenant_clients',
            sa.Column('id', models.types.StringUUID(), nullable=False),
            sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
            sa.Column('plugin_id', sa.String(length=255), nullable=False),
            sa.Column('provider', sa.String(length=255), nullable=False),
            sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
            sa.Column('encrypted_oauth_params', models.types.LongText(), nullable=False),
            sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
            sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
        )

    with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
        batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
        batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
        batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
        batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
    if _is_pg(conn):
        with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
            batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
            batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
            batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
            batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
    else:
        with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
            batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'"), nullable=False))
            batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
            batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'"), nullable=False))
            batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
            batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])

    # ### end Alembic commands ###
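The 'API KEY 1'::character varying defaults show the last recurring portability fix: string defaults reflected from PostgreSQL carry an explicit cast, and MySQL rejects the :: syntax outright, so the portable branch keeps just the quoted literal:

import sqlalchemy as sa

# PostgreSQL-reflected default, cast included:
name_pg = sa.Column('name', sa.String(length=256),
        server_default=sa.text("'API KEY 1'::character varying"), nullable=False)

# Bare literal: valid DDL on both PostgreSQL and MySQL.
name_any = sa.Column('name', sa.String(length=256),
        server_default=sa.text("'API KEY 1'"), nullable=False)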
@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


def _is_pg(conn):
    return conn.dialect.name == "postgresql"

# revision identifiers, used by Alembic.
revision = '8bcc02c9bd07'
down_revision = '375fe79ead14'
@@ -19,19 +23,36 @@ depends_on = None

def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tenant_plugin_auto_upgrade_strategies',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
    sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
    sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
    sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
    sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
    sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
    )
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table('tenant_plugin_auto_upgrade_strategies',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
        sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
        sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
        sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
        sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
        sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
        )
    else:
        op.create_table('tenant_plugin_auto_upgrade_strategies',
        sa.Column('id', models.types.StringUUID(), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
        sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
        sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
        sa.Column('exclude_plugins', sa.JSON(), nullable=False),
        sa.Column('include_plugins', sa.JSON(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
        sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
        )
    # ### end Alembic commands ###
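
A note on the hunk above: sa.ARRAY is a PostgreSQL-only column type, so the non-PostgreSQL branch stores the same plugin lists as JSON instead. A hedged sketch of the two column shapes follows; only the exclude_plugins column name is taken from the diff, the variable names are illustrative.

import sqlalchemy as sa

# PostgreSQL: a native array of varchar(255) values.
exclude_plugins_pg = sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False)

# Other dialects (e.g. MySQL): the same list serialized as a JSON document.
exclude_plugins_portable = sa.Column('exclude_plugins', sa.JSON(), nullable=False)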
@@ -7,6 +7,10 @@ Create Date: 2025-07-24 14:50:48.779833
"""
from alembic import op
import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"
import sqlalchemy as sa

@@ -18,8 +22,18 @@ depends_on = None


def upgrade():
    op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
    conn = op.get_bind()

    if _is_pg(conn):
        op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
    else:
        op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")


def downgrade():
    op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")
    conn = op.get_bind()

    if _is_pg(conn):
        op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
    else:
        op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")