mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 07:46:31 +08:00
merge main
This commit is contained in:
commit
a4e2ef6b0c
30
.github/workflows/api-tests.yml
vendored
30
.github/workflows/api-tests.yml
vendored
@ -39,25 +39,11 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --project api --dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Run Unit tests
|
|
||||||
run: |
|
|
||||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
|
||||||
|
|
||||||
- name: Run pyrefly check
|
- name: Run pyrefly check
|
||||||
run: |
|
run: |
|
||||||
cd api
|
cd api
|
||||||
uv add --dev pyrefly
|
uv add --dev pyrefly
|
||||||
uv run pyrefly check || true
|
uv run pyrefly check || true
|
||||||
- name: Coverage Summary
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
# Extract coverage percentage and create a summary
|
|
||||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
|
||||||
|
|
||||||
# Create a detailed coverage summary
|
|
||||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
|
||||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
|
||||||
|
|
||||||
- name: Run dify config tests
|
- name: Run dify config tests
|
||||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||||
@ -93,3 +79,19 @@ jobs:
|
|||||||
|
|
||||||
- name: Run TestContainers
|
- name: Run TestContainers
|
||||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||||
|
|
||||||
|
- name: Run Unit tests
|
||||||
|
run: |
|
||||||
|
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||||
|
|
||||||
|
- name: Coverage Summary
|
||||||
|
run: |
|
||||||
|
set -x
|
||||||
|
# Extract coverage percentage and create a summary
|
||||||
|
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||||
|
|
||||||
|
# Create a detailed coverage summary
|
||||||
|
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||||
|
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
||||||
|
|||||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -1,6 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||||
|
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||||
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
|
|||||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||||
|
|
||||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||||
|
|||||||
@ -14,7 +14,7 @@ The codebase is split into:
|
|||||||
|
|
||||||
- Run backend CLI commands through `uv run --project api <command>`.
|
- Run backend CLI commands through `uv run --project api <command>`.
|
||||||
|
|
||||||
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
|
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||||
|
|
||||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||||
|
|
||||||
|
|||||||
10
README.md
10
README.md
@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
|
|||||||
|
|
||||||
## Advanced Setup
|
## Advanced Setup
|
||||||
|
|
||||||
|
### Custom configurations
|
||||||
|
|
||||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||||
|
|
||||||
|
### Metrics Monitoring with Grafana
|
||||||
|
|
||||||
|
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||||
|
|
||||||
|
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
|
||||||
|
|
||||||
|
### Deployment with Kubernetes
|
||||||
|
|
||||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||||
|
|
||||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||||
|
|||||||
@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
|
|||||||
default="plugin-api-key",
|
default="plugin-api-key",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||||
|
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||||
|
default=300.0,
|
||||||
|
)
|
||||||
|
|
||||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||||
|
|
||||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||||
@ -543,7 +548,7 @@ class UpdateConfig(BaseSettings):
|
|||||||
|
|
||||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||||
# 100KB
|
# 1000 KiB
|
||||||
1024_000,
|
1024_000,
|
||||||
description="Maximum size for variable to trigger final truncation.",
|
description="Maximum size for variable to trigger final truncation.",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -55,3 +55,12 @@ else:
|
|||||||
"properties",
|
"properties",
|
||||||
}
|
}
|
||||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||||
|
|
||||||
|
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||||
|
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||||
|
COOKIE_NAME_PASSPORT = "passport"
|
||||||
|
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||||
|
|
||||||
|
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||||
|
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||||
|
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from constants.languages import supported_language
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import only_edition_cloud
|
from controllers.console.wraps import only_edition_cloud
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs.token import extract_access_token
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
|
||||||
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
|
|||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_token = extract_access_token(request)
|
||||||
if auth_header is None:
|
if not auth_token:
|
||||||
raise Unauthorized("Authorization header is missing.")
|
raise Unauthorized("Authorization header is missing.")
|
||||||
|
|
||||||
if " " not in auth_header:
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
|
||||||
auth_scheme = auth_scheme.lower()
|
|
||||||
|
|
||||||
if auth_scheme != "bearer":
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
|
|
||||||
if auth_token != dify_config.ADMIN_API_KEY:
|
if auth_token != dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
|
|||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("desc", type=str, location="json")
|
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("copyright", type=str, location="json")
|
.add_argument("desc", type=str, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, location="json")
|
.add_argument("copyright", type=str, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
.add_argument("privacy_policy", type=str, location="json")
|
||||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
.add_argument("custom_disclaimer", type=str, location="json")
|
||||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||||
|
|||||||
@ -7,13 +7,12 @@ from werkzeug.exceptions import Forbidden
|
|||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
from .wraps import account_initialization_required, setup_required
|
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
|
|
||||||
api_key_fields = {
|
api_key_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(
|
||||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||||
@ -68,15 +67,12 @@ class BaseApiKeyListResource(Resource):
|
|||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||||
@ -93,7 +89,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
@ -112,9 +108,8 @@ class BaseApiKeyResource(Resource):
|
|||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
@ -158,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for an app"""
|
"""Create a new API key for an app"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -179,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for an app"""
|
"""Delete an API key for an app"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -208,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for a dataset"""
|
"""Create a new API key for a dataset"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
@ -229,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for a dataset"""
|
"""Delete an API key for a dataset"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
|
|||||||
@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
.add_argument("app_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
.add_argument("model_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||||
|
.add_argument("model_name", type=str, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return AdvancedPromptTemplateService.get_prompt(args)
|
return AdvancedPromptTemplateService.get_prompt(args)
|
||||||
|
|||||||
@ -27,9 +27,11 @@ class AgentLogApi(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get agent logs"""
|
"""Get agent logs"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||||
|
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,14 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
@ -42,15 +41,15 @@ class AnnotationReplyActionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||||
|
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if action == "enable":
|
if action == "enable":
|
||||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||||
@ -69,10 +68,8 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||||
return result, 200
|
return result, 200
|
||||||
@ -98,15 +95,12 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id, annotation_setting_id):
|
def post(self, app_id, annotation_setting_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_setting_id = str(annotation_setting_id)
|
annotation_setting_id = str(annotation_setting_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||||
@ -124,10 +118,8 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, job_id, action):
|
def get(self, app_id, job_id, action):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||||
cache_result = redis_client.get(app_annotation_job_key)
|
cache_result = redis_client.get(app_annotation_job_key)
|
||||||
@ -159,10 +151,8 @@ class AnnotationApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
keyword = request.args.get("keyword", default="", type=str)
|
keyword = request.args.get("keyword", default="", type=str)
|
||||||
@ -198,14 +188,14 @@ class AnnotationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
.add_argument("question", required=True, type=str, location="json")
|
||||||
|
.add_argument("answer", required=True, type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||||
return annotation
|
return annotation
|
||||||
@ -213,10 +203,8 @@ class AnnotationApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|
||||||
# Use request.args.getlist to get annotation_ids array directly
|
# Use request.args.getlist to get annotation_ids array directly
|
||||||
@ -249,10 +237,8 @@ class AnnotationExportApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||||
@ -271,16 +257,16 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
.add_argument("question", required=True, type=str, location="json")
|
||||||
|
.add_argument("answer", required=True, type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||||
return annotation
|
return annotation
|
||||||
@ -288,10 +274,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_id, annotation_id):
|
def delete(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||||
@ -310,10 +294,8 @@ class AnnotationBatchImportApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
@ -341,10 +323,8 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, job_id):
|
def get(self, app_id, job_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
@ -376,10 +356,8 @@ class AnnotationHitHistoryListApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, annotation_id):
|
def get(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
enterprise_license_required,
|
enterprise_license_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager
|
from core.ops.ops_trace_manager import OpsTraceManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.validators import validate_description_length
|
from libs.validators import validate_description_length
|
||||||
from models import Account, App
|
from models import App
|
||||||
from services.app_dsl_service import AppDslService, ImportMode
|
from services.app_dsl_service import AppDslService, ImportMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -56,6 +55,7 @@ class AppListApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get app list"""
|
"""Get app list"""
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
def uuid_list(value):
|
def uuid_list(value):
|
||||||
try:
|
try:
|
||||||
@ -63,34 +63,36 @@ class AppListApi(Resource):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
abort(400, message="Invalid UUID format in tag_ids.")
|
abort(400, message="Invalid UUID format in tag_ids.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument(
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
"mode",
|
.add_argument(
|
||||||
type=str,
|
"mode",
|
||||||
choices=[
|
type=str,
|
||||||
"completion",
|
choices=[
|
||||||
"chat",
|
"completion",
|
||||||
"advanced-chat",
|
"chat",
|
||||||
"workflow",
|
"advanced-chat",
|
||||||
"agent-chat",
|
"workflow",
|
||||||
"channel",
|
"agent-chat",
|
||||||
"all",
|
"channel",
|
||||||
],
|
"all",
|
||||||
default="all",
|
],
|
||||||
location="args",
|
default="all",
|
||||||
required=False,
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
.add_argument("name", type=str, location="args", required=False)
|
||||||
|
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||||
|
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||||
)
|
)
|
||||||
parser.add_argument("name", type=str, location="args", required=False)
|
|
||||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
|
||||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# get app list
|
# get app list
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
@ -129,30 +131,26 @@ class AppListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
|
@edit_permission_required
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Create app"""
|
"""Create app"""
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser = (
|
||||||
parser.add_argument("description", type=validate_description_length, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if "mode" not in args or args["mode"] is None:
|
if "mode" not in args or args["mode"] is None:
|
||||||
raise BadRequest("mode is required")
|
raise BadRequest("mode is required")
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
if not isinstance(current_user, Account):
|
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
if current_user.current_tenant_id is None:
|
|
||||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
|
||||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
|
||||||
@ -205,21 +203,20 @@ class AppApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
"""Update app"""
|
"""Update app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = (
|
||||||
if not current_user.is_editor:
|
reqparse.RequestParser()
|
||||||
raise Forbidden()
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("description", type=validate_description_length, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("max_active_requests", type=int, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
)
|
||||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
|
||||||
parser.add_argument("max_active_requests", type=int, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -248,12 +245,9 @@ class AppApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model):
|
def delete(self, app_model):
|
||||||
"""Delete app"""
|
"""Delete app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_service.delete_app(app_model)
|
app_service.delete_app(app_model)
|
||||||
|
|
||||||
@ -283,27 +277,28 @@ class AppCopyApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Copy app"""
|
"""Copy app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("description", type=validate_description_length, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||||
account = cast(Account, current_user)
|
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=current_user,
|
||||||
import_mode=ImportMode.YAML_CONTENT,
|
import_mode=ImportMode.YAML_CONTENT,
|
||||||
yaml_content=yaml_content,
|
yaml_content=yaml_content,
|
||||||
name=args.get("name"),
|
name=args.get("name"),
|
||||||
@ -340,16 +335,15 @@ class AppExportApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Export app"""
|
"""Export app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("workflow_id", type=str, location="args")
|
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||||
|
.add_argument("workflow_id", type=str, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -371,13 +365,9 @@ class AppNameApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -408,14 +398,13 @@ class AppIconApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = (
|
||||||
if not current_user.is_editor:
|
reqparse.RequestParser()
|
||||||
raise Forbidden()
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser = reqparse.RequestParser()
|
)
|
||||||
parser.add_argument("icon", type=str, location="json")
|
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -441,13 +430,9 @@ class AppSiteStatus(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -475,11 +460,11 @@ class AppApiStatus(Resource):
|
|||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -520,13 +505,14 @@ class AppTraceApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
# add app trace
|
# add app trace
|
||||||
if not current_user.is_editor:
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("enabled", type=bool, required=True, location="json")
|
||||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
OpsTraceManager.update_app_tracing_config(
|
OpsTraceManager.update_app_tracing_config(
|
||||||
|
|||||||
@ -1,20 +1,16 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from services.app_dsl_service import AppDslService, ImportStatus
|
from services.app_dsl_service import AppDslService, ImportStatus
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -30,28 +26,29 @@ class AppImportApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_fields)
|
@marshal_with(app_import_fields)
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
|
@edit_permission_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("mode", type=str, required=True, location="json")
|
||||||
parser.add_argument("mode", type=str, required=True, location="json")
|
.add_argument("yaml_content", type=str, location="json")
|
||||||
parser.add_argument("yaml_content", type=str, location="json")
|
.add_argument("yaml_url", type=str, location="json")
|
||||||
parser.add_argument("yaml_url", type=str, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, location="json")
|
.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument("description", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("app_id", type=str, location="json")
|
||||||
parser.add_argument("app_id", type=str, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args["mode"],
|
||||||
@ -83,16 +80,16 @@ class AppImportConfirmApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_fields)
|
@marshal_with(app_import_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, import_id):
|
def post(self, import_id):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Confirm import
|
# Confirm import
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@ -109,10 +106,8 @@ class AppImportCheckDependenciesApi(Resource):
|
|||||||
@get_app_model
|
@get_app_model
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_check_dependencies_fields)
|
@marshal_with(app_import_check_dependencies_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
result = import_service.check_dependencies(app_model=app_model)
|
result = import_service.check_dependencies(app_model=app_model)
|
||||||
|
|||||||
@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("text", type=str, location="json")
|
.add_argument("message_id", type=str, location="json")
|
||||||
parser.add_argument("voice", type=str, location="json")
|
.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument("streaming", type=bool, location="json")
|
.add_argument("voice", type=str, location="json")
|
||||||
|
.add_argument("streaming", type=bool, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
@ -166,8 +168,7 @@ class TextModesApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||||
parser.add_argument("language", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
response = AudioService.transcript_tts_voices(
|
response = AudioService.transcript_tts_voices(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -15,7 +15,7 @@ from controllers.console.app.error import (
|
|||||||
ProviderQuotaExceededError,
|
ProviderQuotaExceededError,
|
||||||
)
|
)
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json", default="")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("query", type=str, required=True, location="json")
|
||||||
raise Forbidden()
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
|
.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
parser.add_argument("query", type=str, required=True, location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
)
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
|
|||||||
@ -1,17 +1,16 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pytz # pip install pytz
|
import pytz
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import (
|
from fields.conversation_fields import (
|
||||||
@ -22,8 +21,8 @@ from fields.conversation_fields import (
|
|||||||
)
|
)
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -57,18 +56,24 @@ class CompletionConversationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_pagination_fields)
|
@marshal_with(conversation_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status",
|
||||||
|
type=str,
|
||||||
|
choices=["annotated", "not_annotated", "all"],
|
||||||
|
default="all",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
|
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
query = sa.select(Conversation).where(
|
query = sa.select(Conversation).where(
|
||||||
@ -84,6 +89,7 @@ class CompletionConversationApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -137,9 +143,8 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_message_detail_fields)
|
@marshal_with(conversation_message_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model, conversation_id):
|
def get(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
return _get_conversation(app_model, conversation_id)
|
return _get_conversation(app_model, conversation_id)
|
||||||
@ -154,14 +159,12 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model, conversation_id):
|
def delete(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -206,26 +209,32 @@ class ChatConversationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(conversation_with_summary_pagination_fields)
|
@marshal_with(conversation_with_summary_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status",
|
||||||
)
|
type=str,
|
||||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
choices=["annotated", "not_annotated", "all"],
|
||||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
default="all",
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
location="args",
|
||||||
parser.add_argument(
|
)
|
||||||
"sort_by",
|
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||||
type=str,
|
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
required=False,
|
.add_argument(
|
||||||
default="-updated_at",
|
"sort_by",
|
||||||
location="args",
|
type=str,
|
||||||
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
|
required=False,
|
||||||
|
default="-updated_at",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -260,6 +269,7 @@ class ChatConversationApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -341,9 +351,8 @@ class ChatConversationDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(conversation_detail_fields)
|
@marshal_with(conversation_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model, conversation_id):
|
def get(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
return _get_conversation(app_model, conversation_id)
|
return _get_conversation(app_model, conversation_id)
|
||||||
@ -358,14 +367,12 @@ class ChatConversationDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model, conversation_id):
|
def delete(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -374,6 +381,7 @@ class ChatConversationDetailApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
def _get_conversation(app_model, conversation_id):
|
def _get_conversation(app_model, conversation_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
conversation = (
|
conversation = (
|
||||||
db.session.query(Conversation)
|
db.session.query(Conversation)
|
||||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||||
|
|||||||
@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
|
|||||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||||
@marshal_with(paginated_conversation_variable_fields)
|
@marshal_with(paginated_conversation_variable_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||||
parser.add_argument("conversation_id", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
|||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
@ -43,16 +42,18 @@ class RuleGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
rules = LLMGenerator.generate_rule_config(
|
rules = LLMGenerator.generate_rule_config(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=args["no_variable"],
|
no_variable=args["no_variable"],
|
||||||
@ -93,17 +94,19 @@ class RuleCodeGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
|
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
code_result = LLMGenerator.generate_code(
|
code_result = LLMGenerator.generate_code(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["code_language"],
|
code_language=args["code_language"],
|
||||||
@ -140,15 +143,17 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
structured_output = LLMGenerator.generate_structured_output(
|
structured_output = LLMGenerator.generate_structured_output(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
)
|
)
|
||||||
@ -189,15 +194,18 @@ class InstructionGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("node_id", type=str, required=False, default="", location="json")
|
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||||
parser.add_argument("current", type=str, required=False, default="", location="json")
|
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||||
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
|
.add_argument("current", type=str, required=False, default="", location="json")
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
code_template = (
|
code_template = (
|
||||||
Python3CodeProvider.get_default_code()
|
Python3CodeProvider.get_default_code()
|
||||||
if args["language"] == "python"
|
if args["language"] == "python"
|
||||||
@ -222,21 +230,21 @@ class InstructionGenerateApi(Resource):
|
|||||||
match node_type:
|
match node_type:
|
||||||
case "llm":
|
case "llm":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "agent":
|
case "agent":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "code":
|
case "code":
|
||||||
return LLMGenerator.generate_code(
|
return LLMGenerator.generate_code(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["language"],
|
code_language=args["language"],
|
||||||
@ -245,7 +253,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
return {"error": f"invalid node type: {node_type}"}
|
return {"error": f"invalid node type: {node_type}"}
|
||||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||||
return LLMGenerator.instruction_modify_legacy(
|
return LLMGenerator.instruction_modify_legacy(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
@ -254,7 +262,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
)
|
)
|
||||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||||
return LLMGenerator.instruction_modify_workflow(
|
return LLMGenerator.instruction_modify_workflow(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
node_id=args["node_id"],
|
node_id=args["node_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
@ -293,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
match args["type"]:
|
match args["type"]:
|
||||||
case "prompt":
|
case "prompt":
|
||||||
|
|||||||
@ -1,16 +1,15 @@
|
|||||||
import json
|
import json
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_server_fields
|
from fields.app_fields import app_server_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import AppMCPServer
|
from models.model import AppMCPServer
|
||||||
|
|
||||||
|
|
||||||
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
|
|||||||
@api.doc(description="Get MCP server configuration for an application")
|
@api.doc(description="Get MCP server configuration for an application")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||||
@setup_required
|
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@setup_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||||
@api.response(403, "Insufficient permissions")
|
@api.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@login_required
|
||||||
|
@setup_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise NotFound()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
.add_argument("parameters", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
description = args.get("description")
|
description = args.get("description")
|
||||||
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
|
|||||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||||
status=AppMCPServerStatus.ACTIVE,
|
status=AppMCPServerStatus.ACTIVE,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
server_code=AppMCPServer.generate_server_code(16),
|
server_code=AppMCPServer.generate_server_code(16),
|
||||||
)
|
)
|
||||||
db.session.add(server)
|
db.session.add(server)
|
||||||
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
|
|||||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||||
@api.response(403, "Insufficient permissions")
|
@api.response(403, "Insufficient permissions")
|
||||||
@api.response(404, "Server not found")
|
@api.response(404, "Server not found")
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@login_required
|
||||||
|
@setup_required
|
||||||
|
@account_initialization_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
if not current_user.is_editor:
|
parser = (
|
||||||
raise NotFound()
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("id", type=str, required=True, location="json")
|
||||||
parser.add_argument("id", type=str, required=True, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("parameters", type=dict, required=True, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
.add_argument("status", type=str, required=False, location="json")
|
||||||
parser.add_argument("status", type=str, required=False, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||||
if not server:
|
if not server:
|
||||||
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, server_id):
|
def get(self, server_id):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise NotFound()
|
|
||||||
server = (
|
server = (
|
||||||
db.session.query(AppMCPServer)
|
db.session.query(AppMCPServer)
|
||||||
.where(AppMCPServer.id == server_id)
|
.where(AppMCPServer.id == server_id)
|
||||||
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not server:
|
if not server:
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import logging
|
|||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
|
|||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@ -26,8 +27,7 @@ from extensions.ext_database import db
|
|||||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -56,19 +56,19 @@ class ChatMessageListApi(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||||
@api.response(404, "Conversation not found")
|
@api.response(404, "Conversation not found")
|
||||||
@setup_required
|
|
||||||
@login_required
|
@login_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@setup_required
|
||||||
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
)
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
conversation = (
|
conversation = (
|
||||||
@ -154,12 +154,13 @@ class MessageFeedbackApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if current_user is None:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||||
|
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = str(args["message_id"])
|
message_id = str(args["message_id"])
|
||||||
@ -211,23 +212,21 @@ class MessageAnnotationApi(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||||
@api.response(403, "Insufficient permissions")
|
@api.response(403, "Insufficient permissions")
|
||||||
|
@marshal_with(annotation_fields)
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@get_app_model
|
@account_initialization_required
|
||||||
@marshal_with(annotation_fields)
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||||
raise Forbidden()
|
.add_argument("question", required=True, type=str, location="json")
|
||||||
|
.add_argument("answer", required=True, type=str, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
)
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
|
||||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||||
|
|
||||||
@ -270,6 +269,7 @@ class MessageSuggestedQuestionApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def get(self, app_model, message_id):
|
def get(self, app_model, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -304,12 +304,12 @@ class MessageApi(Resource):
|
|||||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||||
@api.response(404, "Message not found")
|
@api.response(404, "Message not found")
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
@marshal_with(message_detail_fields)
|
@marshal_with(message_detail_fields)
|
||||||
def get(self, app_model, message_id):
|
def get(self, app_model, message_id: str):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import json
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
|||||||
from events.app_event import app_model_config_was_updated
|
from events.app_event import app_model_config_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, AppModelConfig
|
from models.model import AppMode, AppModelConfig
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Modify app model config"""
|
"""Modify app model config"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
|
||||||
# validate config
|
# validate config
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
config=cast(dict, request.json),
|
config=cast(dict, request.json),
|
||||||
app_mode=AppMode.value_of(app_model.mode),
|
app_mode=AppMode.value_of(app_model.mode),
|
||||||
)
|
)
|
||||||
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
|
|||||||
# get tool
|
# get tool
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
|
|||||||
@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
"""Create a new trace app configuration"""
|
"""Create a new trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
|
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, app_id):
|
def patch(self, app_id):
|
||||||
"""Update an existing trace app configuration"""
|
"""Update an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
|
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
"""Delete an existing trace app configuration"""
|
"""Delete an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, Site
|
from models import Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("title", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
.add_argument("title", type=str, required=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=False, location="json")
|
.add_argument("icon_type", type=str, required=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
.add_argument("icon", type=str, required=False, location="json")
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("icon_background", type=str, required=False, location="json")
|
||||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
.add_argument("copyright", type=str, required=False, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
.add_argument(
|
||||||
|
"customize_token_strategy",
|
||||||
|
type=str,
|
||||||
|
choices=["must", "allow", "not_allow"],
|
||||||
|
required=False,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||||
|
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||||
|
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
|
||||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
|
||||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -76,9 +81,10 @@ class AppSite(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parse_app_site_args()
|
args = parse_app_site_args()
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be editor, admin, or owner
|
# The role of the current user in the ta table must be editor, admin, or owner
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
@ -107,8 +113,6 @@ class AppSite(Resource):
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
||||||
site.code = Site.generate_code(16)
|
site.code = Site.generate_code(16)
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from decimal import Decimal
|
|||||||
import pytz
|
import pytz
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import AppMode, Message
|
from models import AppMode, Message
|
||||||
|
|
||||||
|
|
||||||
@ -37,11 +36,13 @@ class DailyMessageStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -53,6 +54,7 @@ WHERE
|
|||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
@ -109,13 +111,15 @@ class DailyConversationStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -175,11 +179,13 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -191,7 +197,7 @@ WHERE
|
|||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -247,11 +253,13 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -264,7 +272,7 @@ WHERE
|
|||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -322,11 +330,13 @@ class AverageSessionInteractionStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -346,7 +356,7 @@ FROM
|
|||||||
c.app_id = :app_id
|
c.app_id = :app_id
|
||||||
AND m.invoke_from != :invoke_from"""
|
AND m.invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -413,11 +423,13 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -433,7 +445,7 @@ WHERE
|
|||||||
m.app_id = :app_id
|
m.app_id = :app_id
|
||||||
AND m.invoke_from != :invoke_from"""
|
AND m.invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -494,11 +506,13 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -510,7 +524,7 @@ WHERE
|
|||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -566,11 +580,13 @@ class TokensPerSecondStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -585,7 +601,7 @@ WHERE
|
|||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import services
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
@ -27,9 +27,8 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
|
|||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get draft workflow
|
Get draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
@ -110,24 +105,24 @@ class DraftWorkflowApi(Resource):
|
|||||||
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
||||||
@api.response(400, "Invalid workflow configuration")
|
@api.response(400, "Invalid workflow configuration")
|
||||||
@api.response(403, "Permission denied")
|
@api.response(403, "Permission denied")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||||
|
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -149,10 +144,6 @@ class DraftWorkflowApi(Resource):
|
|||||||
return {"message": "Invalid JSON data"}, 400
|
return {"message": "Invalid JSON data"}, 400
|
||||||
else:
|
else:
|
||||||
abort(415)
|
abort(415)
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -206,24 +197,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("query", type=str, required=True, location="json", default="")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
.add_argument("files", type=list, location="json")
|
||||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("files", type=list, location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
)
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -271,18 +259,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -323,18 +306,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -375,19 +353,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -428,19 +400,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -480,20 +446,17 @@ class DraftWorkflowRunApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
raise Forbidden()
|
)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
external_trace_id = get_external_trace_id(request)
|
external_trace_id = get_external_trace_id(request)
|
||||||
@ -526,17 +489,11 @@ class WorkflowTaskStopApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, task_id: str):
|
def post(self, app_model: App, task_id: str):
|
||||||
"""
|
"""
|
||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Stop using both mechanisms for backward compatibility
|
# Stop using both mechanisms for backward compatibility
|
||||||
# Legacy stop flag mechanism (without user check)
|
# Legacy stop flag mechanism (without user check)
|
||||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||||
@ -568,21 +525,18 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_run_node_execution_fields)
|
@marshal_with(workflow_run_node_execution_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow node
|
Run draft workflow node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("query", type=str, required=False, location="json", default="")
|
||||||
raise Forbidden()
|
.add_argument("files", type=list, location="json", default=[])
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("query", type=str, required=False, location="json", default="")
|
|
||||||
parser.add_argument("files", type=list, location="json", default=[])
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_inputs = args.get("inputs")
|
user_inputs = args.get("inputs")
|
||||||
@ -622,17 +576,11 @@ class PublishedWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get published workflow
|
Get published workflow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch published workflow by app_model
|
# fetch published workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||||
@ -644,19 +592,17 @@ class PublishedWorkflowApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Publish workflow
|
Publish workflow
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||||
raise Forbidden()
|
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
|
||||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -702,17 +648,11 @@ class DefaultBlockConfigsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
return workflow_service.get_default_block_configs()
|
return workflow_service.get_default_block_configs()
|
||||||
@ -729,18 +669,12 @@ class DefaultBlockConfigApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App, block_type: str):
|
def get(self, app_model: App, block_type: str):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("q", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
q = args.get("q")
|
q = args.get("q")
|
||||||
@ -769,24 +703,23 @@ class ConvertToWorkflowApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Convert basic mode of chatbot app to workflow mode
|
Convert basic mode of chatbot app to workflow mode
|
||||||
Convert expert mode of chatbot app to workflow mode
|
Convert expert mode of chatbot app to workflow mode
|
||||||
Convert Completion App to Workflow App
|
Convert Completion App to Workflow App
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if request.data:
|
if request.data:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
args = {}
|
args = {}
|
||||||
@ -812,21 +745,20 @@ class PublishedAllWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_pagination_fields)
|
@marshal_with(workflow_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get published workflows
|
Get published workflows
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
raise Forbidden()
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
.add_argument("user_id", type=str, required=False, location="args")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
|
||||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
page = int(args.get("page", 1))
|
page = int(args.get("page", 1))
|
||||||
limit = int(args.get("limit", 10))
|
limit = int(args.get("limit", 10))
|
||||||
@ -879,19 +811,17 @@ class WorkflowByIdApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def patch(self, app_model: App, workflow_id: str):
|
def patch(self, app_model: App, workflow_id: str):
|
||||||
"""
|
"""
|
||||||
Update workflow attributes
|
Update workflow attributes
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
# Check permission
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("marked_name", type=str, required=False, location="json")
|
||||||
raise Forbidden()
|
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
|
||||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -934,16 +864,11 @@ class WorkflowByIdApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model: App, workflow_id: str):
|
def delete(self, app_model: App, workflow_id: str):
|
||||||
"""
|
"""
|
||||||
Delete workflow
|
Delete workflow
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# Check permission
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
# Create a session and manage the transaction
|
# Create a session and manage the transaction
|
||||||
|
|||||||
@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow app logs
|
Get workflow app logs
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("keyword", type=str, location="args")
|
||||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
.add_argument(
|
||||||
|
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_by_end_user_session_id",
|
||||||
|
type=str,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_by_account",
|
||||||
|
type=str,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_by_end_user_session_id",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_by_account",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||||
|
|||||||
@ -22,8 +22,7 @@ from extensions.ext_database import db
|
|||||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
from factories.variable_factory import build_segment_with_type
|
from factories.variable_factory import build_segment_with_type
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models import App, AppMode
|
from models import Account, App, AppMode
|
||||||
from models.account import Account
|
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
|||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
def _create_pagination_parser():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"page",
|
.add_argument(
|
||||||
type=inputs.int_range(1, 100_000),
|
"page",
|
||||||
required=False,
|
type=inputs.int_range(1, 100_000),
|
||||||
default=1,
|
required=False,
|
||||||
location="args",
|
default=1,
|
||||||
help="the page of data requested",
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -320,10 +321,11 @@ class VariableApi(Resource):
|
|||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
# Parse 'value' field as-is to maintain its original data structure
|
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
|
|
||||||
@ -9,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.workflow_run_fields import (
|
from fields.workflow_run_fields import (
|
||||||
advanced_chat_workflow_run_pagination_fields,
|
advanced_chat_workflow_run_pagination_fields,
|
||||||
|
workflow_run_count_fields,
|
||||||
workflow_run_detail_fields,
|
workflow_run_detail_fields,
|
||||||
workflow_run_node_execution_list_fields,
|
workflow_run_node_execution_list_fields,
|
||||||
workflow_run_pagination_fields,
|
workflow_run_pagination_fields,
|
||||||
)
|
)
|
||||||
|
from libs.custom_inputs import time_duration
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models import Account, App, AppMode, EndUser
|
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
# Workflow run status choices for filtering
|
||||||
|
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_workflow_run_list_args():
|
||||||
|
"""
|
||||||
|
Parse common arguments for workflow run list endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||||
|
"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
parser.add_argument(
|
||||||
|
"status",
|
||||||
|
type=str,
|
||||||
|
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"triggered_from",
|
||||||
|
type=str,
|
||||||
|
choices=["debugging", "app-run"],
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Filter by trigger source: debugging or app-run",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_workflow_run_count_args():
|
||||||
|
"""
|
||||||
|
Parse common arguments for workflow run count endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed arguments containing status, time_range, and triggered_from filters
|
||||||
|
"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"status",
|
||||||
|
type=str,
|
||||||
|
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"time_range",
|
||||||
|
type=time_duration,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"triggered_from",
|
||||||
|
type=str,
|
||||||
|
choices=["debugging", "app-run"],
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Filter by trigger source: debugging or app-run",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||||
@ -25,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||||||
@api.doc(description="Get advanced chat workflow run list")
|
@api.doc(description="Get advanced chat workflow run list")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -35,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get advanced chat app workflow run list
|
Get advanced chat app workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
args = _parse_workflow_run_list_args()
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
# Default to DEBUGGING if not specified
|
||||||
args = parser.parse_args()
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||||
|
app_model=app_model, args=args, triggered_from=triggered_from
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||||
|
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||||
|
@api.doc("get_advanced_chat_workflow_runs_count")
|
||||||
|
@api.doc(description="Get advanced chat workflow runs count statistics")
|
||||||
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(
|
||||||
|
params={
|
||||||
|
"time_range": (
|
||||||
|
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||||
|
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
|
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@marshal_with(workflow_run_count_fields)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get advanced chat workflow runs count statistics
|
||||||
|
"""
|
||||||
|
args = _parse_workflow_run_count_args()
|
||||||
|
|
||||||
|
# Default to DEBUGGING if not specified
|
||||||
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_run_service = WorkflowRunService()
|
||||||
|
result = workflow_run_service.get_workflow_runs_count(
|
||||||
|
app_model=app_model,
|
||||||
|
status=args.get("status"),
|
||||||
|
time_range=args.get("time_range"),
|
||||||
|
triggered_from=triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -52,6 +170,8 @@ class WorkflowRunListApi(Resource):
|
|||||||
@api.doc(description="Get workflow run list")
|
@api.doc(description="Get workflow run list")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -62,13 +182,64 @@ class WorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
args = _parse_workflow_run_list_args()
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
args = parser.parse_args()
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
result = workflow_run_service.get_paginate_workflow_runs(
|
||||||
|
app_model=app_model, args=args, triggered_from=triggered_from
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||||
|
class WorkflowRunCountApi(Resource):
|
||||||
|
@api.doc("get_workflow_runs_count")
|
||||||
|
@api.doc(description="Get workflow runs count statistics")
|
||||||
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(
|
||||||
|
params={
|
||||||
|
"time_range": (
|
||||||
|
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||||
|
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
|
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@marshal_with(workflow_run_count_fields)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get workflow runs count statistics
|
||||||
|
"""
|
||||||
|
args = _parse_workflow_run_count_args()
|
||||||
|
|
||||||
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_run_service = WorkflowRunService()
|
||||||
|
result = workflow_run_service.get_workflow_runs_count(
|
||||||
|
app_model=app_model,
|
||||||
|
status=args.get("status"),
|
||||||
|
time_range=args.get("time_range"),
|
||||||
|
triggered_from=triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from decimal import Decimal
|
|||||||
import pytz
|
import pytz
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import jsonify
|
from flask import jsonify
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
@ -29,11 +28,13 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -49,7 +50,7 @@ WHERE
|
|||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
}
|
}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -97,11 +98,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -117,7 +120,7 @@ WHERE
|
|||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
}
|
}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -165,11 +168,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -185,7 +190,7 @@ WHERE
|
|||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
}
|
}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
@ -238,11 +243,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -271,7 +278,7 @@ GROUP BY
|
|||||||
"app_id": app_model.id,
|
"app_id": app_model.id,
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
}
|
}
|
||||||
|
assert account.timezone is not None
|
||||||
timezone = pytz.timezone(account.timezone)
|
timezone = pytz.timezone(account.timezone)
|
||||||
utc_timezone = pytz.utc
|
utc_timezone = pytz.utc
|
||||||
|
|
||||||
|
|||||||
@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
|
|||||||
|
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import App, AppMode
|
from models import App, AppMode
|
||||||
from models.account import Account
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
P1 = ParamSpec("P1")
|
||||||
|
R1 = TypeVar("R1")
|
||||||
|
|
||||||
|
|
||||||
def _load_app_model(app_id: str) -> App | None:
|
def _load_app_model(app_id: str) -> App | None:
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
app_model = (
|
app_model = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||||
def decorator(view_func: Callable[P, R]):
|
def decorator(view_func: Callable[P1, R1]):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||||
from models.account import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import AccountService, RegisterService
|
||||||
|
|
||||||
active_check_parser = reqparse.RequestParser()
|
active_check_parser = (
|
||||||
active_check_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||||
)
|
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||||
active_check_parser.add_argument(
|
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
|
||||||
)
|
|
||||||
active_check_parser.add_argument(
|
|
||||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
|
|||||||
return {"is_valid": False}
|
return {"is_valid": False}
|
||||||
|
|
||||||
|
|
||||||
active_parser = reqparse.RequestParser()
|
active_parser = (
|
||||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
active_parser.add_argument(
|
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate")
|
@console_ns.route("/activate")
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from ..wraps import account_initialization_required, setup_required
|
||||||
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||||
if data_source_api_key_bindings:
|
if data_source_api_key_bindings:
|
||||||
return {
|
return {
|
||||||
"sources": [
|
"sources": [
|
||||||
@ -41,16 +41,20 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
@ -63,9 +67,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, binding_id):
|
def delete(self, binding_id):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|||||||
@ -2,13 +2,12 @@ import logging
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask import current_app, redirect, request
|
from flask import current_app, redirect, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.oauth_data_source import NotionOAuth
|
from libs.oauth_data_source import NotionOAuth
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from ..wraps import account_initialization_required, setup_required
|
||||||
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
|
|||||||
@api.response(403, "Admin privileges required")
|
@api.response(403, "Admin privileges required")
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -26,7 +24,17 @@ from controllers.console.error import (
|
|||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from models.account import Account
|
from libs.login import current_account_with_tenant
|
||||||
|
from libs.token import (
|
||||||
|
clear_access_token_from_cookie,
|
||||||
|
clear_csrf_token_from_cookie,
|
||||||
|
clear_refresh_token_from_cookie,
|
||||||
|
extract_access_token,
|
||||||
|
extract_csrf_token,
|
||||||
|
set_access_token_to_cookie,
|
||||||
|
set_csrf_token_to_cookie,
|
||||||
|
set_refresh_token_to_cookie,
|
||||||
|
)
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountRegisterError
|
from services.errors.account import AccountRegisterError
|
||||||
@ -42,11 +50,13 @@ class LoginApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Authenticate user and login."""
|
"""Authenticate user and login."""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str, required=True, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
.add_argument("password", type=str, required=True, location="json")
|
||||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||||
|
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||||
@ -89,19 +99,36 @@ class LoginApi(Resource):
|
|||||||
|
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
# Create response with cookies instead of returning tokens in body
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/logout")
|
@console_ns.route("/logout")
|
||||||
class LogoutApi(Resource):
|
class LogoutApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def post(self):
|
||||||
account = cast(Account, flask_login.current_user)
|
current_user, _ = current_account_with_tenant()
|
||||||
|
account = current_user
|
||||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
return {"result": "success"}
|
response = make_response({"result": "success"})
|
||||||
AccountService.logout(account=account)
|
else:
|
||||||
flask_login.logout_user()
|
AccountService.logout(account=account)
|
||||||
return {"result": "success"}
|
flask_login.logout_user()
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
# Clear cookies on logout
|
||||||
|
clear_access_token_from_cookie(response)
|
||||||
|
clear_refresh_token_from_cookie(response)
|
||||||
|
clear_csrf_token_from_cookie(response)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/reset-password")
|
@console_ns.route("/reset-password")
|
||||||
@ -109,9 +136,11 @@ class ResetPasswordSendEmailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||||
@ -137,9 +166,11 @@ class ResetPasswordSendEmailApi(Resource):
|
|||||||
class EmailCodeLoginSendEmailApi(Resource):
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -170,10 +201,12 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||||||
class EmailCodeLoginApi(Resource):
|
class EmailCodeLoginApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -220,18 +253,46 @@ class EmailCodeLoginApi(Resource):
|
|||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
# Create response with cookies instead of returning tokens in body
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
# Set HTTP-only secure cookies for tokens
|
||||||
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/refresh-token")
|
@console_ns.route("/refresh-token")
|
||||||
class RefreshTokenApi(Resource):
|
class RefreshTokenApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
# Get refresh token from cookie instead of request body
|
||||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
refresh_token = request.cookies.get("refresh_token")
|
||||||
args = parser.parse_args()
|
|
||||||
|
if not refresh_token:
|
||||||
|
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
|
||||||
|
# Create response with new cookies
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
# Update cookies with new tokens
|
||||||
|
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||||
|
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"result": "fail", "data": str(e)}, 401
|
return {"result": "fail", "message": str(e)}, 401
|
||||||
|
|
||||||
|
|
||||||
|
# this api helps frontend to check whether user is authenticated
|
||||||
|
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
|
||||||
|
@console_ns.route("/login/status")
|
||||||
|
class LoginStatus(Resource):
|
||||||
|
def get(self):
|
||||||
|
token = extract_access_token(request)
|
||||||
|
csrf_token = extract_csrf_token(request)
|
||||||
|
return {"logged_in": bool(token) and bool(csrf_token)}
|
||||||
|
|||||||
@ -14,8 +14,12 @@ from extensions.ext_database import db
|
|||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
from models import Account
|
from libs.token import (
|
||||||
from models.account import AccountStatus
|
set_access_token_to_cookie,
|
||||||
|
set_csrf_token_to_cookie,
|
||||||
|
set_refresh_token_to_cookie,
|
||||||
|
)
|
||||||
|
from models import Account, AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
@ -153,9 +157,12 @@ class OAuthCallback(Resource):
|
|||||||
ip_address=extract_remote_ip(request),
|
ip_address=extract_remote_ip(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
return redirect(
|
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
|
||||||
)
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||||
|
|||||||
@ -1,16 +1,15 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.model import OAuthProviderApp
|
from models.model import OAuthProviderApp
|
||||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||||
|
|
||||||
@ -24,8 +23,7 @@ T = TypeVar("T")
|
|||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
client_id = parsed_args.get("client_id")
|
client_id = parsed_args.get("client_id")
|
||||||
if not client_id:
|
if not client_id:
|
||||||
@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
redirect_uri = parsed_args.get("redirect_uri")
|
redirect_uri = parsed_args.get("redirect_uri")
|
||||||
|
|
||||||
@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
account = cast(Account, flask_login.current_user)
|
current_user, _ = current_account_with_tenant()
|
||||||
|
account = current_user
|
||||||
user_account_id = account.id
|
user_account_id = account.id
|
||||||
|
|
||||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||||
@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=False, location="json")
|
.add_argument("grant_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
.add_argument("code", type=str, required=False, location="json")
|
||||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
.add_argument("client_secret", type=str, required=False, location="json")
|
||||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||||
|
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
|
|||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
@ -14,17 +13,15 @@ class Subscription(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
parser = (
|
||||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
reqparse.RequestParser()
|
||||||
args = parser.parse_args()
|
.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||||
assert isinstance(current_user, Account)
|
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||||
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
return BillingService.get_subscription(
|
|
||||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
|
||||||
)
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/billing/invoices")
|
@console_ns.route("/billing/invoices")
|
||||||
@ -34,7 +31,6 @@ class Invoices(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
assert current_user.current_tenant_id is not None
|
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
|
||||||
|
|||||||
@ -2,8 +2,7 @@ from flask import request
|
|||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
from .. import console_ns
|
from .. import console_ns
|
||||||
@ -17,19 +16,16 @@ class ComplianceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||||
|
|
||||||
return BillingService.get_compliance_download_link(
|
return BillingService.get_compliance_download_link(
|
||||||
doc_name=args.doc_name,
|
doc_name=args.doc_name,
|
||||||
account_id=current_user.id,
|
account_id=current_user.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
ip=ip_address,
|
ip=ip_address,
|
||||||
device_info=device_info,
|
device_info=device_info,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from collections.abc import Generator
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import DataSourceOauthBinding, Document
|
from models import DataSourceOauthBinding, Document
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# get workspace data source integrates
|
# get workspace data source integrates
|
||||||
data_source_integrates = db.session.scalars(
|
data_source_integrates = db.session.scalars(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||||
DataSourceOauthBinding.disabled == False,
|
DataSourceOauthBinding.disabled == False,
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_notion_info_list_fields)
|
@marshal_with(integrate_notion_info_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
documents = session.scalars(
|
documents = session.scalars(
|
||||||
select(Document).filter_by(
|
select(Document).filter_by(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
data_source_type="notion_import",
|
data_source_type="notion_import",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||||
datasource_name="notion_datasource",
|
datasource_name="notion_datasource",
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||||
)
|
)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, workspace_id, page_id, page_type):
|
def get(self, workspace_id, page_id, page_type):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=credential.get("integration_secret"),
|
notion_access_token=credential.get("integration_secret"),
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_docs = extractor.extract()
|
text_docs = extractor.extract()
|
||||||
@ -239,12 +244,14 @@ class DataSourceNotionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
parser = (
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# validate args
|
# validate args
|
||||||
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
"notion_page_type": page["type"],
|
"notion_page_type": page["type"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
args["process_rule"],
|
args["process_rule"],
|
||||||
args["doc_form"],
|
args["doc_form"],
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -30,10 +29,9 @@ from extensions.ext_database import db
|
|||||||
from fields.app_fields import related_app_list
|
from fields.app_fields import related_app_list
|
||||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||||
from fields.document_fields import document_status_fields
|
from fields.document_fields import document_status_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.validators import validate_description_length
|
from libs.validators import validate_description_length
|
||||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
@ -138,6 +136,7 @@ class DatasetListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
ids = request.args.getlist("ids")
|
ids = request.args.getlist("ids")
|
||||||
@ -146,15 +145,15 @@ class DatasetListApi(Resource):
|
|||||||
tag_ids = request.args.getlist("tag_ids")
|
tag_ids = request.args.getlist("tag_ids")
|
||||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||||
if ids:
|
if ids:
|
||||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
|
||||||
else:
|
else:
|
||||||
datasets, total = DatasetService.get_datasets(
|
datasets, total = DatasetService.get_datasets(
|
||||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
|
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
|
||||||
)
|
)
|
||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||||
|
|
||||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
|
|
||||||
@ -207,50 +206,53 @@ class DatasetListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=validate_description_length,
|
"description",
|
||||||
nullable=True,
|
type=validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
location="json",
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
location="json",
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
help="Invalid indexing technique.",
|
nullable=True,
|
||||||
)
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument(
|
)
|
||||||
"external_knowledge_api_id",
|
.add_argument(
|
||||||
type=str,
|
"external_knowledge_api_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
required=False,
|
nullable=True,
|
||||||
)
|
required=False,
|
||||||
parser.add_argument(
|
)
|
||||||
"provider",
|
.add_argument(
|
||||||
type=str,
|
"provider",
|
||||||
nullable=True,
|
type=str,
|
||||||
choices=Dataset.PROVIDER_LIST,
|
nullable=True,
|
||||||
required=False,
|
choices=Dataset.PROVIDER_LIST,
|
||||||
default="vendor",
|
required=False,
|
||||||
)
|
default="vendor",
|
||||||
parser.add_argument(
|
)
|
||||||
"external_knowledge_id",
|
.add_argument(
|
||||||
type=str,
|
"external_knowledge_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
required=False,
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
@ -258,11 +260,11 @@ class DatasetListApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
description=args["description"],
|
description=args["description"],
|
||||||
indexing_technique=args["indexing_technique"],
|
indexing_technique=args["indexing_technique"],
|
||||||
account=cast(Account, current_user),
|
account=current_user,
|
||||||
permission=DatasetPermissionEnum.ONLY_ME,
|
permission=DatasetPermissionEnum.ONLY_ME,
|
||||||
provider=args["provider"],
|
provider=args["provider"],
|
||||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||||
@ -286,6 +288,7 @@ class DatasetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -305,7 +308,7 @@ class DatasetApi(Resource):
|
|||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||||
|
|
||||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
|
|
||||||
@ -351,73 +354,76 @@ class DatasetApi(Resource):
|
|||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
nullable=False,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
)
|
||||||
parser.add_argument(
|
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
location="json",
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
location="json",
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
help="Invalid indexing technique.",
|
nullable=True,
|
||||||
)
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument(
|
)
|
||||||
"permission",
|
.add_argument(
|
||||||
type=str,
|
"permission",
|
||||||
location="json",
|
type=str,
|
||||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
location="json",
|
||||||
help="Invalid permission.",
|
choices=(
|
||||||
)
|
DatasetPermissionEnum.ONLY_ME,
|
||||||
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
DatasetPermissionEnum.ALL_TEAM,
|
||||||
parser.add_argument(
|
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
),
|
||||||
)
|
help="Invalid permission.",
|
||||||
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
)
|
||||||
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||||
|
.add_argument(
|
||||||
parser.add_argument(
|
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||||
"external_retrieval_model",
|
)
|
||||||
type=dict,
|
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||||
required=False,
|
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||||
nullable=True,
|
.add_argument(
|
||||||
location="json",
|
"external_retrieval_model",
|
||||||
help="Invalid external retrieval model.",
|
type=dict,
|
||||||
)
|
required=False,
|
||||||
|
nullable=True,
|
||||||
parser.add_argument(
|
location="json",
|
||||||
"external_knowledge_id",
|
help="Invalid external retrieval model.",
|
||||||
type=str,
|
)
|
||||||
required=False,
|
.add_argument(
|
||||||
nullable=True,
|
"external_knowledge_id",
|
||||||
location="json",
|
type=str,
|
||||||
help="Invalid external knowledge id.",
|
required=False,
|
||||||
)
|
nullable=True,
|
||||||
|
location="json",
|
||||||
parser.add_argument(
|
help="Invalid external knowledge id.",
|
||||||
"external_knowledge_api_id",
|
)
|
||||||
type=str,
|
.add_argument(
|
||||||
required=False,
|
"external_knowledge_api_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
location="json",
|
required=False,
|
||||||
help="Invalid external knowledge api id.",
|
nullable=True,
|
||||||
)
|
location="json",
|
||||||
|
help="Invalid external knowledge api id.",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
required=False,
|
type=dict,
|
||||||
nullable=True,
|
required=False,
|
||||||
location="json",
|
nullable=True,
|
||||||
help="Invalid icon info.",
|
location="json",
|
||||||
|
help="Invalid icon info.",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if (
|
if (
|
||||||
@ -440,7 +446,7 @@ class DatasetApi(Resource):
|
|||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
@ -464,9 +470,9 @@ class DatasetApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id):
|
def delete(self, dataset_id):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -505,6 +511,7 @@ class DatasetQueryApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -539,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
required=True,
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
required=True,
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
location="json",
|
nullable=True,
|
||||||
)
|
location="json",
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
)
|
||||||
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
if args["info_list"]["data_source_type"] == "upload_file":
|
if args["info_list"]["data_source_type"] == "upload_file":
|
||||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||||
file_details = db.session.scalars(
|
file_details = db.session.scalars(
|
||||||
select(UploadFile).where(
|
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
|
||||||
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
if file_details is None:
|
if file_details is None:
|
||||||
@ -592,7 +598,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
"notion_page_type": page["type"],
|
"notion_page_type": page["type"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
@ -608,7 +614,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
"provider": website_info_list["provider"],
|
"provider": website_info_list["provider"],
|
||||||
"job_id": website_info_list["job_id"],
|
"job_id": website_info_list["job_id"],
|
||||||
"url": url,
|
"url": url,
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
"mode": "crawl",
|
"mode": "crawl",
|
||||||
"only_main_content": website_info_list["only_main_content"],
|
"only_main_content": website_info_list["only_main_content"],
|
||||||
}
|
}
|
||||||
@ -621,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
try:
|
try:
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
args["process_rule"],
|
args["process_rule"],
|
||||||
args["doc_form"],
|
args["doc_form"],
|
||||||
@ -652,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(related_app_list)
|
@marshal_with(related_app_list)
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -683,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
documents = db.session.scalars(
|
documents = db.session.scalars(
|
||||||
select(Document).where(
|
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
|
||||||
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@ -739,10 +745,9 @@ class DatasetApiKeyApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_list)
|
@marshal_with(api_key_list)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@ -752,12 +757,13 @@ class DatasetApiKeyApi(Resource):
|
|||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -770,7 +776,7 @@ class DatasetApiKeyApi(Resource):
|
|||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
@ -790,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, api_key_id):
|
def delete(self, api_key_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
@ -799,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
|
|||||||
key = (
|
key = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(
|
.where(
|
||||||
ApiToken.tenant_id == current_user.current_tenant_id,
|
ApiToken.tenant_id == current_tenant_id,
|
||||||
ApiToken.type == self.resource_type,
|
ApiToken.type == self.resource_type,
|
||||||
ApiToken.id == api_key_id,
|
ApiToken.id == api_key_id,
|
||||||
)
|
)
|
||||||
@ -898,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from typing import Literal, cast
|
|||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import asc, desc, select
|
from sqlalchemy import asc, desc, select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -53,9 +52,8 @@ from fields.document_fields import (
|
|||||||
document_with_segments_fields,
|
document_with_segments_fields,
|
||||||
)
|
)
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import DocumentPipelineExecutionLog
|
from models.dataset import DocumentPipelineExecutionLog
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||||
@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DocumentResource(Resource):
|
class DocumentResource(Resource):
|
||||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -79,12 +78,13 @@ class DocumentResource(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
if document.tenant_id != current_user.current_tenant_id:
|
if document.tenant_id != current_tenant_id:
|
||||||
raise Forbidden("No permission.")
|
raise Forbidden("No permission.")
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
||||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
req_data = request.args
|
req_data = request.args
|
||||||
|
|
||||||
document_id = req_data.get("document_id")
|
document_id = req_data.get("document_id")
|
||||||
@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search = f"%{search}%"
|
search = f"%{search}%"
|
||||||
@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
|
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -289,20 +292,20 @@ class DatasetDocumentListApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
.add_argument(
|
||||||
)
|
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||||
parser.add_argument("data_source", type=dict, required=False, location="json")
|
)
|
||||||
parser.add_argument("process_rule", type=dict, required=False, location="json")
|
.add_argument("data_source", type=dict, required=False, location="json")
|
||||||
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||||
@ -372,27 +375,28 @@ class DatasetInitApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
type=str,
|
||||||
required=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
nullable=False,
|
required=True,
|
||||||
location="json",
|
nullable=False,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
|
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
parser.add_argument(
|
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
|
||||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||||
@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=args["embedding_model_provider"],
|
provider=args["embedding_model_provider"],
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=args["embedding_model"],
|
model=args["embedding_model"],
|
||||||
@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
knowledge_config=knowledge_config,
|
knowledge_config=knowledge_config,
|
||||||
account=cast(Account, current_user),
|
account=current_user,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
estimate_response = indexing_runner.indexing_estimate(
|
estimate_response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
[extract_setting],
|
[extract_setting],
|
||||||
data_process_rule_dict,
|
data_process_rule_dict,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, batch):
|
def get(self, dataset_id, batch):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
batch = str(batch)
|
batch = str(batch)
|
||||||
documents = self.get_batch_documents(dataset_id, batch)
|
documents = self.get_batch_documents(dataset_id, batch)
|
||||||
@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
file_detail = (
|
file_detail = (
|
||||||
db.session.query(UploadFile)
|
db.session.query(UploadFile)
|
||||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
|
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
"notion_page_type": data_source_info["type"],
|
"notion_page_type": data_source_info["type"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
document_model=document.doc_form,
|
document_model=document.doc_form,
|
||||||
@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
"provider": data_source_info["provider"],
|
"provider": data_source_info["provider"],
|
||||||
"job_id": data_source_info["job_id"],
|
"job_id": data_source_info["job_id"],
|
||||||
"url": data_source_info["url"],
|
"url": data_source_info["url"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"tenant_id": current_tenant_id,
|
||||||
"mode": data_source_info["mode"],
|
"mode": data_source_info["mode"],
|
||||||
"only_main_content": data_source_info["only_main_content"],
|
"only_main_content": data_source_info["only_main_content"],
|
||||||
}
|
}
|
||||||
@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
try:
|
try:
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
data_process_rule_dict,
|
data_process_rule_dict,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, dataset_id, document_id):
|
def put(self, dataset_id, document_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -1034,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
|
|||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
"""retry document."""
|
"""retry document."""
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
|
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -1077,14 +1087,14 @@ class DocumentRenameApi(DocumentResource):
|
|||||||
@marshal_with(document_fields)
|
@marshal_with(document_fields)
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
"""sync website document."""
|
"""sync website document."""
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
if document.tenant_id != current_user.current_tenant_id:
|
if document.tenant_id != current_tenant_id:
|
||||||
raise Forbidden("No permission.")
|
raise Forbidden("No permission.")
|
||||||
if document.data_source_type != "website_crawl":
|
if document.data_source_type != "website_crawl":
|
||||||
raise ValueError("Document is not a website document.")
|
raise ValueError("Document is not a website document.")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import ChildChunk, DocumentSegment
|
from models.dataset import ChildChunk, DocumentSegment
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -59,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("limit", type=int, default=20, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
.add_argument("limit", type=int, default=20, location="args")
|
||||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
.add_argument("enabled", type=str, default="all", location="args")
|
||||||
parser.add_argument("page", type=int, default=1, location="args")
|
.add_argument("keyword", type=str, default=None, location="args")
|
||||||
|
.add_argument("page", type=int, default=1, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -79,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
select(DocumentSegment)
|
select(DocumentSegment)
|
||||||
.where(
|
.where(
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id),
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
DocumentSegment.tenant_id == current_tenant_id,
|
||||||
)
|
)
|
||||||
.order_by(DocumentSegment.position.asc())
|
.order_by(DocumentSegment.position.asc())
|
||||||
)
|
)
|
||||||
@ -115,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id):
|
def delete(self, dataset_id, document_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -148,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, action):
|
def patch(self, dataset_id, document_id, action):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -171,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -204,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -221,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -237,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
segment = SegmentService.create_segment(args, document, dataset)
|
segment = SegmentService.create_segment(args, document, dataset)
|
||||||
@ -255,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -272,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -287,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -300,12 +313,14 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
.add_argument(
|
||||||
|
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
@ -317,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id):
|
def delete(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -333,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -361,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -372,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
|
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
upload_file_id = args["upload_file_id"]
|
upload_file_id = args["upload_file_id"]
|
||||||
|
|
||||||
@ -396,7 +416,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
upload_file_id,
|
upload_file_id,
|
||||||
dataset_id,
|
dataset_id,
|
||||||
document_id,
|
document_id,
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
current_user.id,
|
current_user.id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -427,6 +447,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id, segment_id):
|
def post(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -441,7 +463,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -453,7 +475,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -469,8 +491,9 @@ class ChildChunkAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
"content", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
content = args["content"]
|
content = args["content"]
|
||||||
@ -483,6 +506,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id, segment_id):
|
def get(self, dataset_id, document_id, segment_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -499,15 +524,17 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("limit", type=int, default=20, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
.add_argument("limit", type=int, default=20, location="args")
|
||||||
parser.add_argument("page", type=int, default=1, location="args")
|
.add_argument("keyword", type=str, default=None, location="args")
|
||||||
|
.add_argument("page", type=int, default=1, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -530,6 +557,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -546,7 +575,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -559,8 +588,9 @@ class ChildChunkAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
"chunks", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
chunks_data = args["chunks"]
|
chunks_data = args["chunks"]
|
||||||
@ -580,6 +610,8 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -596,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -607,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
@ -634,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -650,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -661,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
@ -677,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
"content", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
content = args["content"]
|
content = args["content"]
|
||||||
|
|||||||
@ -1,7 +1,4 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, reqparse
|
from flask_restx import Resource, fields, marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
@ -10,8 +7,7 @@ from controllers.console import api, console_ns
|
|||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
search = request.args.get("keyword", default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
|
|
||||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||||
page, limit, current_user.current_tenant_id, search
|
page, limit, current_tenant_id, search
|
||||||
)
|
)
|
||||||
response = {
|
response = {
|
||||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||||
@ -60,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument(
|
parser = (
|
||||||
"name",
|
reqparse.RequestParser()
|
||||||
nullable=False,
|
.add_argument(
|
||||||
required=True,
|
"name",
|
||||||
help="Name is required. Name must be between 1 to 100 characters.",
|
nullable=False,
|
||||||
type=_validate_name,
|
required=True,
|
||||||
)
|
help="Name is required. Name must be between 1 to 100 characters.",
|
||||||
parser.add_argument(
|
type=_validate_name,
|
||||||
"settings",
|
)
|
||||||
type=dict,
|
.add_argument(
|
||||||
location="json",
|
"settings",
|
||||||
nullable=False,
|
type=dict,
|
||||||
required=True,
|
location="json",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
@ -115,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, external_knowledge_api_id):
|
def patch(self, external_knowledge_api_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="type is required. Name must be between 1 to 100 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 100 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"settings",
|
.add_argument(
|
||||||
type=dict,
|
"settings",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=False,
|
location="json",
|
||||||
required=True,
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ExternalDatasetService.validate_api_list(args["settings"])
|
ExternalDatasetService.validate_api_list(args["settings"])
|
||||||
|
|
||||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
external_knowledge_api_id=external_knowledge_api_id,
|
external_knowledge_api_id=external_knowledge_api_id,
|
||||||
args=args,
|
args=args,
|
||||||
@ -148,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, external_knowledge_api_id):
|
def delete(self, external_knowledge_api_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
@ -199,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="name is required. Name must be between 1 to 100 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="name is required. Name must be between 1 to 100 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -223,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset = ExternalDatasetService.create_external_dataset(
|
dataset = ExternalDatasetService.create_external_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
@ -255,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -265,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("query", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
@ -277,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
response = HitTestingService.external_retrieve(
|
response = HitTestingService.external_retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
query=args["query"],
|
query=args["query"],
|
||||||
account=cast(Account, current_user),
|
account=current_user,
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||||
)
|
)
|
||||||
@ -304,15 +313,17 @@ class BedrockRetrievalApi(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(200, "Bedrock retrieval test completed")
|
@api.response(200, "Bedrock retrieval test completed")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||||
"query",
|
.add_argument(
|
||||||
nullable=False,
|
"query",
|
||||||
required=True,
|
nullable=False,
|
||||||
type=str,
|
required=True,
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||||
)
|
)
|
||||||
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Call the knowledge retrieval service
|
# Call the knowledge retrieval service
|
||||||
|
|||||||
@ -48,11 +48,12 @@ class DatasetsHitTestingBase:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json")
|
.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||||
from fields.dataset_fields import dataset_metadata_fields
|
from fields.dataset_fields import dataset_metadata_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
MetadataArgs,
|
MetadataArgs,
|
||||||
@ -24,9 +23,12 @@ class DatasetMetadataCreateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
parser = reqparse.RequestParser()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
|
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
metadata_args = MetadataArgs.model_validate(args)
|
metadata_args = MetadataArgs.model_validate(args)
|
||||||
|
|
||||||
@ -59,8 +61,8 @@ class DatasetMetadataApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
def patch(self, dataset_id, metadata_id):
|
def patch(self, dataset_id, metadata_id):
|
||||||
parser = reqparse.RequestParser()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
name = args["name"]
|
name = args["name"]
|
||||||
|
|
||||||
@ -79,6 +81,7 @@ class DatasetMetadataApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def delete(self, dataset_id, metadata_id):
|
def delete(self, dataset_id, metadata_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
metadata_id_str = str(metadata_id)
|
metadata_id_str = str(metadata_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
@ -108,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -128,14 +132,16 @@ class DocumentMetadataEditApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
metadata_args = MetadataOperationData.model_validate(args)
|
metadata_args = MetadataOperationData.model_validate(args)
|
||||||
|
|
||||||
|
|||||||
@ -1,19 +1,15 @@
|
|||||||
from flask import make_response, redirect, request
|
from flask import make_response, redirect, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
account_initialization_required,
|
|
||||||
setup_required,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from libs.helper import StrLen
|
from libs.helper import StrLen
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
user = current_user
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
if not current_user.is_editor:
|
tenant_id = current_tenant_id
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id")
|
credential_id = request.args.get("credential_id")
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||||
authorization_url_response = oauth_handler.get_authorization_url(
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=current_user.id,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
@ -130,22 +126,24 @@ class DatasourceAuth(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
.add_argument(
|
||||||
|
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||||
|
)
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
datasource_provider_service.add_datasource_api_key_provider(
|
datasource_provider_service.add_datasource_api_key_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=datasource_provider_id,
|
provider_id=datasource_provider_id,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -160,8 +158,10 @@ class DatasourceAuth(Resource):
|
|||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasources = datasource_provider_service.list_datasource_credentials(
|
datasources = datasource_provider_service.list_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
)
|
)
|
||||||
@ -173,18 +173,21 @@ class DatasourceAuthDeleteApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
plugin_id = datasource_provider_id.plugin_id
|
plugin_id = datasource_provider_id.plugin_id
|
||||||
provider_name = datasource_provider_id.provider_name
|
provider_name = datasource_provider_id.provider_name
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser = reqparse.RequestParser()
|
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_datasource_credentials(
|
datasource_provider_service.remove_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@ -197,18 +200,22 @@ class DatasourceAuthUpdateApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_credentials(
|
datasource_provider_service.update_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
@ -224,10 +231,10 @@ class DatasourceAuthListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
@ -237,10 +244,10 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
@ -249,17 +256,20 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.setup_oauth_custom_client_params(
|
datasource_provider_service.setup_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=args.get("client_params", {}),
|
||||||
enabled=args.get("enable_oauth_custom_client", False),
|
enabled=args.get("enable_oauth_custom_client", False),
|
||||||
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider_id: str):
|
def delete(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_oauth_custom_client_params(
|
datasource_provider_service.remove_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
@ -284,16 +296,16 @@ class DatasourceAuthDefaultApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.set_default_datasource_provider(
|
datasource_provider_service.set_default_datasource_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
credential_id=args["id"],
|
credential_id=args["id"],
|
||||||
)
|
)
|
||||||
@ -305,17 +317,20 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_provider_name(
|
datasource_provider_service.update_datasource_provider_name(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
|
|||||||
@ -26,10 +26,12 @@ class DataSourceContentPreviewApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
|
|||||||
@ -66,26 +66,28 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def patch(self, template_id: str):
|
def patch(self, template_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=_validate_description_length,
|
"description",
|
||||||
nullable=True,
|
type=_validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=True,
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||||
@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@knowledge_pipeline_publish_enabled
|
@knowledge_pipeline_publish_enabled
|
||||||
def post(self, pipeline_id: str):
|
def post(self, pipeline_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=_validate_description_length,
|
"description",
|
||||||
nullable=True,
|
type=_validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=True,
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@ -13,7 +12,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||||
@ -27,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"yaml_content",
|
"yaml_content",
|
||||||
type=str,
|
type=str,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
@ -38,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -58,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||||
)
|
)
|
||||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
import_info["dataset_id"],
|
import_info["dataset_id"],
|
||||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||||
)
|
)
|
||||||
@ -81,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||||
name="",
|
name="",
|
||||||
description="",
|
description="",
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from extensions.ext_database import db
|
|||||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
from factories.variable_factory import build_segment_with_type
|
from factories.variable_factory import build_segment_with_type
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
def _create_pagination_parser():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"page",
|
.add_argument(
|
||||||
type=inputs.int_range(1, 100_000),
|
"page",
|
||||||
required=False,
|
type=inputs.int_range(1, 100_000),
|
||||||
default=1,
|
required=False,
|
||||||
location="args",
|
default=1,
|
||||||
help="the page of data requested",
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
|
|||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
# Parse 'value' field as-is to maintain its original data structure
|
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
|
|||||||
@ -1,6 +1,3 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user # type: ignore
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
@ -13,8 +10,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from services.app_dsl_service import ImportStatus
|
from services.app_dsl_service import ImportStatus
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
@ -28,26 +24,29 @@ class RagPipelineImportApi(Resource):
|
|||||||
@marshal_with(pipeline_import_fields)
|
@marshal_with(pipeline_import_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("mode", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("yaml_content", type=str, location="json")
|
.add_argument("mode", type=str, required=True, location="json")
|
||||||
parser.add_argument("yaml_url", type=str, location="json")
|
.add_argument("yaml_content", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, location="json")
|
.add_argument("yaml_url", type=str, location="json")
|
||||||
parser.add_argument("description", type=str, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("pipeline_id", type=str, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
.add_argument("pipeline_id", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = RagPipelineDslService(session)
|
import_service = RagPipelineDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.import_rag_pipeline(
|
result = import_service.import_rag_pipeline(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args["mode"],
|
||||||
@ -74,15 +73,16 @@ class RagPipelineImportConfirmApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(pipeline_import_fields)
|
@marshal_with(pipeline_import_fields)
|
||||||
def post(self, import_id):
|
def post(self, import_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = RagPipelineDslService(session)
|
import_service = RagPipelineDslService(session)
|
||||||
# Confirm import
|
# Confirm import
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
@ -100,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
@ -117,12 +118,12 @@ class RagPipelineExportApi(Resource):
|
|||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
|||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
|
|||||||
)
|
)
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, current_user, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get draft rag pipeline's workflow
|
Get draft rag pipeline's workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||||
@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def post(self, pipeline: Pipeline):
|
def post(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
.add_argument("environment_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
|
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -154,16 +154,15 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def post(self, pipeline: Pipeline, node_id: str):
|
def post(self, pipeline: Pipeline, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -194,11 +193,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -229,14 +228,17 @@ class DraftRagPipelineRunApi(Resource):
|
|||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||||
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -264,17 +266,20 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
Run published workflow
|
Run published workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||||
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||||
|
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
@ -303,15 +308,16 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
# Run rag pipeline datasource
|
# Run rag pipeline datasource
|
||||||
# """
|
# """
|
||||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||||
# if not current_user.is_editor:
|
# if not current_user.has_edit_permission:
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# if not isinstance(current_user, Account):
|
# if not isinstance(current_user, Account):
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# parser = reqparse.RequestParser()
|
# parser = (reqparse.RequestParser()
|
||||||
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||||
# parser.add_argument("datasource_type", type=str, required=True, location="json")
|
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
# )
|
||||||
# args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
#
|
#
|
||||||
# job_id = args.get("job_id")
|
# job_id = args.get("job_id")
|
||||||
@ -344,15 +350,16 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
# Run rag pipeline datasource
|
# Run rag pipeline datasource
|
||||||
# """
|
# """
|
||||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||||
# if not current_user.is_editor:
|
# if not current_user.has_edit_permission:
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# if not isinstance(current_user, Account):
|
# if not isinstance(current_user, Account):
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# parser = reqparse.RequestParser()
|
# parser = (reqparse.RequestParser()
|
||||||
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||||
# parser.add_argument("datasource_type", type=str, required=True, location="json")
|
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
# )
|
||||||
# args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
#
|
#
|
||||||
# job_id = args.get("job_id")
|
# job_id = args.get("job_id")
|
||||||
@ -385,13 +392,16 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||||||
Run rag pipeline datasource
|
Run rag pipeline datasource
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -428,13 +438,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||||||
Run rag pipeline datasource
|
Run rag pipeline datasource
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -472,11 +485,13 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||||||
Run draft workflow node
|
Run draft workflow node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -505,7 +520,8 @@ class RagPipelineTaskStopApi(Resource):
|
|||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
@ -525,7 +541,8 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
Get published pipeline
|
Get published pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
if not pipeline.is_published:
|
if not pipeline.is_published:
|
||||||
return None
|
return None
|
||||||
@ -545,7 +562,8 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
Publish workflow
|
Publish workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
@ -580,7 +598,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
@ -599,11 +618,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||||
parser.add_argument("q", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
q = args.get("q")
|
q = args.get("q")
|
||||||
@ -631,14 +650,17 @@ class PublishedAllRagPipelineApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get published workflows
|
Get published workflows
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
.add_argument("user_id", type=str, required=False, location="args")
|
||||||
|
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
page = int(args.get("page", 1))
|
page = int(args.get("page", 1))
|
||||||
limit = int(args.get("limit", 10))
|
limit = int(args.get("limit", 10))
|
||||||
@ -681,12 +703,15 @@ class RagPipelineByIdApi(Resource):
|
|||||||
Update workflow attributes
|
Update workflow attributes
|
||||||
"""
|
"""
|
||||||
# Check permission
|
# Check permission
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
.add_argument("marked_name", type=str, required=False, location="json")
|
||||||
|
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -733,15 +758,12 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -759,15 +781,12 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -785,15 +804,12 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -811,15 +827,12 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -843,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
@ -880,7 +895,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
@marshal_with(workflow_run_node_execution_list_fields)
|
@marshal_with(workflow_run_node_execution_list_fields)
|
||||||
def get(self, pipeline: Pipeline, run_id):
|
def get(self, pipeline: Pipeline, run_id: str):
|
||||||
"""
|
"""
|
||||||
Get workflow run node execution list
|
Get workflow run node execution list
|
||||||
"""
|
"""
|
||||||
@ -903,14 +918,8 @@ class DatasourceListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
if not isinstance(user, Account):
|
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
|
||||||
raise Forbidden()
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
if not tenant_id:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||||
@ -940,9 +949,8 @@ class RagPipelineTransformApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -959,19 +967,20 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(workflow_run_node_execution_fields)
|
@marshal_with(workflow_run_node_execution_fields)
|
||||||
def post(self, pipeline: Pipeline):
|
def post(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Set datasource variables
|
Set datasource variables
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||||
parser.add_argument("datasource_info", type=dict, required=True, location="json")
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_title", type=str, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
|||||||
@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"provider",
|
.add_argument(
|
||||||
type=str,
|
"provider",
|
||||||
choices=["firecrawl", "watercrawl", "jinareader"],
|
type=str,
|
||||||
required=True,
|
choices=["firecrawl", "watercrawl", "jinareader"],
|
||||||
nullable=True,
|
required=True,
|
||||||
location="json",
|
nullable=True,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create typed request and validate
|
# Create typed request and validate
|
||||||
@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id: str):
|
def get(self, job_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -3,8 +3,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from controllers.console.datasets.error import PipelineNotFoundError
|
from controllers.console.datasets.error import PipelineNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
|
|
||||||
|
|
||||||
@ -17,8 +16,7 @@ def get_rag_pipeline(
|
|||||||
if not kwargs.get("pipeline_id"):
|
if not kwargs.get("pipeline_id"):
|
||||||
raise ValueError("missing pipeline_id in path parameters")
|
raise ValueError("missing pipeline_id in path parameters")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user is not an account")
|
|
||||||
|
|
||||||
pipeline_id = kwargs.get("pipeline_id")
|
pipeline_id = kwargs.get("pipeline_id")
|
||||||
pipeline_id = str(pipeline_id)
|
pipeline_id = str(pipeline_id)
|
||||||
@ -27,7 +25,7 @@ def get_rag_pipeline(
|
|||||||
|
|
||||||
pipeline = (
|
pipeline = (
|
||||||
db.session.query(Pipeline)
|
db.session.query(Pipeline)
|
||||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -81,11 +81,13 @@ class ChatTextApi(InstalledAppResource):
|
|||||||
|
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("voice", type=str, location="json")
|
.add_argument("message_id", type=str, required=False, location="json")
|
||||||
parser.add_argument("text", type=str, location="json")
|
.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument("streaming", type=bool, location="json")
|
.add_argument("text", type=str, location="json")
|
||||||
|
.add_argument("streaming", type=bool, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
|
|||||||
@ -49,12 +49,14 @@ class CompletionApi(InstalledAppResource):
|
|||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json", default="")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
@ -121,13 +123,15 @@ class ChatApi(InstalledAppResource):
|
|||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
|
|||||||
@ -31,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
|
|||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pinned = None
|
pinned = None
|
||||||
@ -94,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
|
|||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
.add_argument("name", type=str, required=False, location="json")
|
||||||
|
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -12,10 +12,9 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.app_service import AppService
|
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
@ -29,9 +28,7 @@ class InstalledAppsListApi(Resource):
|
|||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
app_id = request.args.get("app_id", default=None, type=str)
|
app_id = request.args.get("app_id", default=None, type=str)
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
if app_id:
|
if app_id:
|
||||||
installed_apps = db.session.scalars(
|
installed_apps = db.session.scalars(
|
||||||
@ -69,31 +66,26 @@ class InstalledAppsListApi(Resource):
|
|||||||
|
|
||||||
# Pre-filter out apps without setting or with sso_verified
|
# Pre-filter out apps without setting or with sso_verified
|
||||||
filtered_installed_apps = []
|
filtered_installed_apps = []
|
||||||
app_id_to_app_code = {}
|
|
||||||
|
|
||||||
for installed_app in installed_app_list:
|
for installed_app in installed_app_list:
|
||||||
app_id = installed_app["app"].id
|
app_id = installed_app["app"].id
|
||||||
webapp_setting = webapp_settings.get(app_id)
|
webapp_setting = webapp_settings.get(app_id)
|
||||||
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
||||||
continue
|
continue
|
||||||
app_code = AppService.get_app_code_by_id(str(app_id))
|
|
||||||
app_id_to_app_code[app_id] = app_code
|
|
||||||
filtered_installed_apps.append(installed_app)
|
filtered_installed_apps.append(installed_app)
|
||||||
|
|
||||||
app_codes = list(app_id_to_app_code.values())
|
|
||||||
|
|
||||||
# Batch permission check
|
# Batch permission check
|
||||||
|
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
|
||||||
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
app_codes=app_codes,
|
app_ids=app_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep only allowed apps
|
# Keep only allowed apps
|
||||||
res = []
|
res = []
|
||||||
for installed_app in filtered_installed_apps:
|
for installed_app in filtered_installed_apps:
|
||||||
app_id = installed_app["app"].id
|
app_id = installed_app["app"].id
|
||||||
app_code = app_id_to_app_code[app_id]
|
if permissions.get(app_id):
|
||||||
if permissions.get(app_code):
|
|
||||||
res.append(installed_app)
|
res.append(installed_app)
|
||||||
|
|
||||||
installed_app_list = res
|
installed_app_list = res
|
||||||
@ -113,17 +105,15 @@ class InstalledAppsListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||||
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App not found")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
@ -163,9 +153,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
|
||||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
db.session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
@ -174,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||||
|
|
||||||
def patch(self, installed_app):
|
def patch(self, installed_app):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
||||||
parser.add_argument("is_pinned", type=inputs.boolean)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
commit_args = False
|
commit_args = False
|
||||||
|
|||||||
@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
|
|||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
@ -48,21 +47,22 @@ logger = logging.getLogger(__name__)
|
|||||||
class MessageListApi(InstalledAppResource):
|
class MessageListApi(InstalledAppResource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("first_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
)
|
)
|
||||||
@ -78,18 +78,19 @@ class MessageListApi(InstalledAppResource):
|
|||||||
)
|
)
|
||||||
class MessageFeedbackApi(InstalledAppResource):
|
class MessageFeedbackApi(InstalledAppResource):
|
||||||
def post(self, installed_app, message_id):
|
def post(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("content", type=str, location="json")
|
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
|
.add_argument("content", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
@ -109,14 +110,14 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||||||
)
|
)
|
||||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -124,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
@ -159,6 +158,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||||||
)
|
)
|
||||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
@ -167,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
)
|
)
|
||||||
|
|||||||
@ -42,8 +42,7 @@ class RecommendedAppListApi(Resource):
|
|||||||
@marshal_with(recommended_app_list_fields)
|
@marshal_with(recommended_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# language args
|
# language args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||||
parser.add_argument("language", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
language = args.get("language")
|
language = args.get("language")
|
||||||
|
|||||||
@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
|
|||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import Account
|
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
@ -35,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
|
|
||||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||||
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
@ -72,6 +70,7 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
)
|
)
|
||||||
class SavedMessageApi(InstalledAppResource):
|
class SavedMessageApi(InstalledAppResource):
|
||||||
def delete(self, installed_app, message_id):
|
def delete(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
@ -79,8 +78,6 @@ class SavedMessageApi(InstalledAppResource):
|
|||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.delete(app_model, current_user, message_id)
|
SavedMessageService.delete(app_model, current_user, message_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from core.errors.error import (
|
|||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.login import current_user
|
from libs.login import current_user as current_user_
|
||||||
from models.model import AppMode, InstalledApp
|
from models.model import AppMode, InstalledApp
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
@ -31,6 +31,8 @@ from .. import console_ns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
current_user = current_user_._get_current_object() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||||
@ -45,9 +47,11 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert current_user is not None
|
assert current_user is not None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -8,10 +8,8 @@ from werkzeug.exceptions import NotFound
|
|||||||
from controllers.console.explore.error import AppAccessDeniedError
|
from controllers.console.explore.error import AppAccessDeniedError
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import InstalledApp
|
from models import InstalledApp
|
||||||
from models.account import Account
|
|
||||||
from services.app_service import AppService
|
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
@ -24,13 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
|
||||||
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -56,14 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
app_code = AppService.get_app_code_by_id(app_id)
|
|
||||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
app_code=app_code,
|
app_id=app_id,
|
||||||
)
|
)
|
||||||
if not res:
|
if not res:
|
||||||
raise AppAccessDeniedError()
|
raise AppAccessDeniedError()
|
||||||
|
|||||||
@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
@ -30,8 +29,7 @@ class CodeBasedExtensionAPI(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
|
||||||
parser.add_argument("module", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||||
@ -47,9 +45,7 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
_, tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||||
|
|
||||||
@api.doc("create_api_based_extension")
|
@api.doc("create_api_based_extension")
|
||||||
@ -70,16 +66,17 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
assert isinstance(current_user, Account)
|
parser = (
|
||||||
assert current_user.current_tenant_id is not None
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
.add_argument("api_key", type=str, required=True, location="json")
|
||||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
api_endpoint=args["api_endpoint"],
|
api_endpoint=args["api_endpoint"],
|
||||||
api_key=args["api_key"],
|
api_key=args["api_key"],
|
||||||
@ -99,10 +96,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self, id):
|
def get(self, id):
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
@ -125,17 +120,17 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self, id):
|
def post(self, id):
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||||
|
.add_argument("api_key", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
extension_data_from_db.name = args["name"]
|
extension_data_from_db.name = args["name"]
|
||||||
@ -154,12 +149,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, id):
|
def delete(self, id):
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
APIBasedExtensionService.delete(extension_data_from_db)
|
APIBasedExtensionService.delete(extension_data_from_db)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
@ -23,9 +22,9 @@ class FeatureApi(Resource):
|
|||||||
@cloud_utm_record
|
@cloud_utm_record
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get feature configuration for current tenant"""
|
"""Get feature configuration for current tenant"""
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/system-features")
|
@console_ns.route("/system-features")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with
|
from flask_restx import Resource, marshal_with
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -22,8 +21,7 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from fields.file_fields import file_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
from . import console_ns
|
from . import console_ns
|
||||||
@ -53,6 +51,7 @@ class FileApi(Resource):
|
|||||||
@marshal_with(file_fields)
|
@marshal_with(file_fields)
|
||||||
@cloud_edition_billing_resource_check("documents")
|
@cloud_edition_billing_resource_check("documents")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
source_str = request.form.get("source")
|
source_str = request.form.get("source")
|
||||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||||
|
|
||||||
@ -65,16 +64,12 @@ class FileApi(Resource):
|
|||||||
|
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
if source == "datasets" and not current_user.is_dataset_editor:
|
if source == "datasets" and not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if source not in ("datasets", None):
|
if source not in ("datasets", None):
|
||||||
source = None
|
source = None
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
@ -108,4 +103,4 @@ class FileSupportTypeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||||
|
|||||||
@ -57,8 +57,7 @@ class InitValidateAPI(Resource):
|
|||||||
if tenant_count > 0:
|
if tenant_count > 0:
|
||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
|
||||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
|
||||||
input_password = parser.parse_args()["password"]
|
input_password = parser.parse_args()["password"]
|
||||||
|
|
||||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||||
|
|||||||
@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
|
|||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account
|
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
from . import console_ns
|
from . import console_ns
|
||||||
@ -41,8 +40,7 @@ class RemoteFileInfoApi(Resource):
|
|||||||
class RemoteFileUploadApi(Resource):
|
class RemoteFileUploadApi(Resource):
|
||||||
@marshal_with(file_fields_with_signed_url)
|
@marshal_with(file_fields_with_signed_url)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
url = args["url"]
|
url = args["url"]
|
||||||
@ -64,8 +62,7 @@ class RemoteFileUploadApi(Resource):
|
|||||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert isinstance(current_user, Account)
|
user, _ = current_account_with_tenant()
|
||||||
user = current_user
|
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file_info.filename,
|
filename=file_info.filename,
|
||||||
content=content,
|
content=content,
|
||||||
|
|||||||
@ -69,10 +69,12 @@ class SetupApi(Resource):
|
|||||||
if not get_init_validate_status():
|
if not get_init_validate_status():
|
||||||
raise NotInitValidateError()
|
raise NotInitValidateError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||||
|
.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
|
|||||||
@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden
|
|||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.tag_fields import dataset_tag_fields
|
from fields.tag_fields import dataset_tag_fields
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import Tag
|
from models.model import Tag
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
|
|
||||||
@ -24,11 +23,10 @@ class TagListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(dataset_tag_fields)
|
@marshal_with(dataset_tag_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
tag_type = request.args.get("type", type=str, default="")
|
tag_type = request.args.get("type", type=str, default="")
|
||||||
keyword = request.args.get("keyword", default=None, type=str)
|
keyword = request.args.get("keyword", default=None, type=str)
|
||||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
|
||||||
|
|
||||||
return tags, 200
|
return tags, 200
|
||||||
|
|
||||||
@ -36,18 +34,23 @@ class TagListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, _ = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
.add_argument(
|
||||||
)
|
"name",
|
||||||
parser.add_argument(
|
nullable=False,
|
||||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
required=True,
|
||||||
|
help="Name must be between 1 to 50 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tag = TagService.save_tags(args)
|
tag = TagService.save_tags(args)
|
||||||
@ -63,15 +66,13 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, tag_id):
|
def patch(self, tag_id):
|
||||||
assert isinstance(current_user, Account)
|
current_user, _ = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
tag_id = str(tag_id)
|
tag_id = str(tag_id)
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -87,8 +88,7 @@ class TagUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, tag_id):
|
def delete(self, tag_id):
|
||||||
assert isinstance(current_user, Account)
|
current_user, _ = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
tag_id = str(tag_id)
|
tag_id = str(tag_id)
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
@ -105,21 +105,22 @@ class TagBindingCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, _ = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
.add_argument(
|
||||||
)
|
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||||
parser.add_argument(
|
)
|
||||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
.add_argument(
|
||||||
)
|
"target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required."
|
||||||
parser.add_argument(
|
)
|
||||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
.add_argument(
|
||||||
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.save_tag_binding(args)
|
TagService.save_tag_binding(args)
|
||||||
@ -133,17 +134,18 @@ class TagBindingDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, _ = current_account_with_tenant()
|
||||||
assert current_user.current_tenant_id is not None
|
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||||
parser.add_argument(
|
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||||
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
.add_argument(
|
||||||
|
"type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
TagService.delete_tag_binding(args)
|
TagService.delete_tag_binding(args)
|
||||||
|
|||||||
@ -37,8 +37,7 @@ class VersionApi(Resource):
|
|||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Check for application version updates"""
|
"""Check for application version updates"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("current_version", type=str, required=True, location="args")
|
||||||
parser.add_argument("current_version", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||||
|
|
||||||
|
|||||||
@ -2,11 +2,11 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import TenantPluginPermission
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
@ -20,8 +20,9 @@ def plugin_permission_required(
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
permission = (
|
permission = (
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from datetime import datetime
|
|||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -37,9 +36,8 @@ from extensions.ext_database import db
|
|||||||
from fields.member_fields import account_fields
|
from fields.member_fields import account_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import TimestampField, email, extract_remote_ip, timezone
|
from libs.helper import TimestampField, email, extract_remote_ip, timezone
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import AccountIntegrate, InvitationCode
|
from models import Account, AccountIntegrate, InvitationCode
|
||||||
from models.account import Account
|
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||||
@ -50,9 +48,7 @@ class AccountInitApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
if account.status == "active":
|
if account.status == "active":
|
||||||
raise AccountAlreadyInitedError()
|
raise AccountAlreadyInitedError()
|
||||||
@ -61,9 +57,9 @@ class AccountInitApi(Resource):
|
|||||||
|
|
||||||
if dify_config.EDITION == "CLOUD":
|
if dify_config.EDITION == "CLOUD":
|
||||||
parser.add_argument("invitation_code", type=str, location="json")
|
parser.add_argument("invitation_code", type=str, location="json")
|
||||||
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
|
||||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
"timezone", type=timezone, required=True, location="json"
|
||||||
parser.add_argument("timezone", type=timezone, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if dify_config.EDITION == "CLOUD":
|
if dify_config.EDITION == "CLOUD":
|
||||||
@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
|
|||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@ -118,10 +113,8 @@ class AccountNameApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate account name length
|
# Validate account name length
|
||||||
@ -140,10 +133,8 @@ class AccountAvatarApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("avatar", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
|
||||||
@ -158,10 +149,10 @@ class AccountInterfaceLanguageApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser = reqparse.RequestParser()
|
"interface_language", type=supported_language, required=True, location="json"
|
||||||
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
|
||||||
@ -176,10 +167,10 @@ class AccountInterfaceThemeApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser = reqparse.RequestParser()
|
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
|
||||||
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
|
||||||
@ -194,10 +185,8 @@ class AccountTimezoneApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("timezone", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
# Validate timezone string, e.g. America/New_York, Asia/Shanghai
|
||||||
@ -216,12 +205,13 @@ class AccountPasswordApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str, required=False, location="json")
|
.add_argument("password", type=str, required=False, location="json")
|
||||||
parser.add_argument("new_password", type=str, required=True, location="json")
|
.add_argument("new_password", type=str, required=True, location="json")
|
||||||
parser.add_argument("repeat_new_password", type=str, required=True, location="json")
|
.add_argument("repeat_new_password", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args["new_password"] != args["repeat_new_password"]:
|
if args["new_password"] != args["repeat_new_password"]:
|
||||||
@ -253,9 +243,7 @@ class AccountIntegrateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
account_integrates = db.session.scalars(
|
account_integrates = db.session.scalars(
|
||||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||||
@ -298,9 +286,7 @@ class AccountDeleteVerifyApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||||
AccountService.send_account_deletion_verification_email(account, code)
|
AccountService.send_account_deletion_verification_email(account, code)
|
||||||
@ -314,13 +300,13 @@ class AccountDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("token", type=str, required=True, location="json")
|
||||||
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
||||||
@ -335,9 +321,11 @@ class AccountDeleteApi(Resource):
|
|||||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("feedback", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
|
.add_argument("feedback", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
||||||
@ -358,9 +346,7 @@ class EducationVerifyApi(Resource):
|
|||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(verify_fields)
|
@marshal_with(verify_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
return BillingService.EducationIdentity.verify(account.id, account.email)
|
return BillingService.EducationIdentity.verify(account.id, account.email)
|
||||||
|
|
||||||
@ -380,14 +366,14 @@ class EducationApi(Resource):
|
|||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("institution", type=str, required=True, location="json")
|
.add_argument("token", type=str, required=True, location="json")
|
||||||
parser.add_argument("role", type=str, required=True, location="json")
|
.add_argument("institution", type=str, required=True, location="json")
|
||||||
|
.add_argument("role", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
|
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
|
||||||
@ -399,9 +385,7 @@ class EducationApi(Resource):
|
|||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(status_fields)
|
@marshal_with(status_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
account, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
account = current_user
|
|
||||||
|
|
||||||
res = BillingService.EducationIdentity.status(account.id)
|
res = BillingService.EducationIdentity.status(account.id)
|
||||||
# convert expire_at to UTC timestamp from isoformat
|
# convert expire_at to UTC timestamp from isoformat
|
||||||
@ -425,10 +409,12 @@ class EducationAutoCompleteApi(Resource):
|
|||||||
@cloud_edition_billing_enabled
|
@cloud_edition_billing_enabled
|
||||||
@marshal_with(data_fields)
|
@marshal_with(data_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("keywords", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("page", type=int, required=False, location="args", default=0)
|
.add_argument("keywords", type=str, required=True, location="args")
|
||||||
parser.add_argument("limit", type=int, required=False, location="args", default=20)
|
.add_argument("page", type=int, required=False, location="args", default=0)
|
||||||
|
.add_argument("limit", type=int, required=False, location="args", default=20)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
|
||||||
@ -441,11 +427,14 @@ class ChangeEmailSendEmailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
parser = (
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("phase", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=False, location="json")
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
.add_argument("phase", type=str, required=False, location="json")
|
||||||
|
.add_argument("token", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -467,8 +456,6 @@ class ChangeEmailSendEmailApi(Resource):
|
|||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
user_email = reset_data.get("email", "")
|
user_email = reset_data.get("email", "")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if user_email != current_user.email:
|
if user_email != current_user.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
else:
|
else:
|
||||||
@ -490,10 +477,12 @@ class ChangeEmailCheckApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -533,9 +522,11 @@ class ChangeEmailResetApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_fields)
|
@marshal_with(account_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("new_email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("new_email", type=email, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if AccountService.is_account_in_freeze(args["new_email"]):
|
if AccountService.is_account_in_freeze(args["new_email"]):
|
||||||
@ -551,8 +542,7 @@ class ChangeEmailResetApi(Resource):
|
|||||||
AccountService.revoke_change_email_token(args["token"])
|
AccountService.revoke_change_email_token(args["token"])
|
||||||
|
|
||||||
old_email = reset_data.get("old_email", "")
|
old_email = reset_data.get("old_email", "")
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if current_user.email != old_email:
|
if current_user.email != old_email:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
|
|
||||||
@ -569,8 +559,7 @@ class ChangeEmailResetApi(Resource):
|
|||||||
class CheckEmailUnique(Resource):
|
class CheckEmailUnique(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if AccountService.is_account_in_freeze(args["email"]):
|
if AccountService.is_account_in_freeze(args["email"]):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|||||||
@ -3,8 +3,7 @@ from flask_restx import Resource, fields
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.agent_service import AgentService
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
@ -21,12 +20,11 @@ class AgentProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
user = current_user
|
||||||
assert user.current_tenant_id is not None
|
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||||
|
|
||||||
@ -45,9 +43,5 @@ class AgentProviderApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_name: str):
|
def get(self, provider_name: str):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||||
assert user.current_tenant_id is not None
|
|
||||||
user_id = user.id
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
|
||||||
|
|||||||
@ -5,18 +5,10 @@ from controllers.console import api, console_ns
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.plugin.endpoint_service import EndpointService
|
from services.plugin.endpoint_service import EndpointService
|
||||||
|
|
||||||
|
|
||||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
assert tenant_id is not None
|
|
||||||
return current_user, tenant_id
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/endpoints/create")
|
@console_ns.route("/workspaces/current/endpoints/create")
|
||||||
class EndpointCreateApi(Resource):
|
class EndpointCreateApi(Resource):
|
||||||
@api.doc("create_endpoint")
|
@api.doc("create_endpoint")
|
||||||
@ -41,14 +33,16 @@ class EndpointCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True)
|
reqparse.RequestParser()
|
||||||
parser.add_argument("settings", type=dict, required=True)
|
.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||||
parser.add_argument("name", type=str, required=True)
|
.add_argument("settings", type=dict, required=True)
|
||||||
|
.add_argument("name", type=str, required=True)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
plugin_unique_identifier = args["plugin_unique_identifier"]
|
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||||
@ -87,11 +81,13 @@ class EndpointListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
.add_argument("page", type=int, required=True, location="args")
|
||||||
|
.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
page = args["page"]
|
page = args["page"]
|
||||||
@ -130,12 +126,14 @@ class EndpointListForSinglePluginApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
.add_argument("page", type=int, required=True, location="args")
|
||||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
page = args["page"]
|
page = args["page"]
|
||||||
@ -172,10 +170,9 @@ class EndpointDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
@ -212,12 +209,14 @@ class EndpointUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
reqparse.RequestParser()
|
||||||
parser.add_argument("settings", type=dict, required=True)
|
.add_argument("endpoint_id", type=str, required=True)
|
||||||
parser.add_argument("name", type=str, required=True)
|
.add_argument("settings", type=dict, required=True)
|
||||||
|
.add_argument("name", type=str, required=True)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
endpoint_id = args["endpoint_id"]
|
endpoint_id = args["endpoint_id"]
|
||||||
@ -255,10 +254,9 @@ class EndpointEnableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
endpoint_id = args["endpoint_id"]
|
endpoint_id = args["endpoint_id"]
|
||||||
@ -288,10 +286,9 @@ class EndpointDisableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user, tenant_id = _current_account_with_tenant()
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
|
||||||
parser.add_argument("endpoint_id", type=str, required=True)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
endpoint_id = args["endpoint_id"]
|
endpoint_id = args["endpoint_id"]
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from controllers.console import console_ns
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account, TenantAccountRole
|
from models import TenantAccountRole
|
||||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||||
|
|
||||||
|
|
||||||
@ -18,24 +18,25 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
assert tenant_id is not None
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# validate model load balancing credentials
|
# validate model load balancing credentials
|
||||||
@ -72,24 +73,25 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str, config_id: str):
|
def post(self, provider: str, config_id: str):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
assert tenant_id is not None
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# validate model load balancing config credentials
|
# validate model load balancing config credentials
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.member_fields import account_with_role_list_fields
|
from fields.member_fields import account_with_role_list_fields
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account, TenantAccountRole
|
from models.account import Account, TenantAccountRole
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.errors.account import AccountAlreadyInTenantError
|
from services.errors.account import AccountAlreadyInTenantError
|
||||||
@ -41,8 +41,7 @@ class MemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||||
@ -58,10 +57,12 @@ class MemberInviteEmailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("members")
|
@cloud_edition_billing_resource_check("members")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("emails", type=list, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("role", type=str, required=True, default="admin", location="json")
|
.add_argument("emails", type=list, required=True, location="json")
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
invitee_emails = args["emails"]
|
invitee_emails = args["emails"]
|
||||||
@ -69,9 +70,7 @@ class MemberInviteEmailApi(Resource):
|
|||||||
interface_language = args["language"]
|
interface_language = args["language"]
|
||||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
inviter = current_user
|
inviter = current_user
|
||||||
if not inviter.current_tenant:
|
if not inviter.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
@ -120,8 +119,7 @@ class MemberCancelInviteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, member_id):
|
def delete(self, member_id):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||||
@ -153,16 +151,13 @@ class MemberUpdateRoleApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, member_id):
|
def put(self, member_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
|
||||||
parser.add_argument("role", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
new_role = args["role"]
|
new_role = args["role"]
|
||||||
|
|
||||||
if not TenantAccountRole.is_valid_role(new_role):
|
if not TenantAccountRole.is_valid_role(new_role):
|
||||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
member = db.session.get(Account, str(member_id))
|
member = db.session.get(Account, str(member_id))
|
||||||
@ -189,8 +184,7 @@ class DatasetOperatorMemberListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(account_with_role_list_fields)
|
@marshal_with(account_with_role_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||||
@ -206,16 +200,13 @@ class SendOwnerTransferEmailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@is_allow_transfer_owner
|
@is_allow_transfer_owner
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
@ -245,13 +236,14 @@ class OwnerTransferCheckApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@is_allow_transfer_owner
|
@is_allow_transfer_owner
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
@ -291,13 +283,13 @@ class OwnerTransfer(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@is_allow_transfer_owner
|
@is_allow_transfer_owner
|
||||||
def post(self, member_id):
|
def post(self, member_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
"token", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if the current user is the owner of the workspace
|
# check if the current user is the owner of the workspace
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.current_tenant:
|
if not current_user.current_tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import send_file
|
from flask import send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
@ -23,14 +21,10 @@ class ModelProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
tenant_id = current_tenant_id
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"model_type",
|
"model_type",
|
||||||
type=str,
|
type=str,
|
||||||
required=False,
|
required=False,
|
||||||
@ -52,14 +46,12 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
tenant_id = current_tenant_id
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
# if credential_id is not provided, return current used credential
|
# if credential_id is not provided, return current used credential
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
@ -73,23 +65,22 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.create_provider_credential(
|
model_provider_service.create_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
credential_name=args["name"],
|
credential_name=args["name"],
|
||||||
@ -103,24 +94,23 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_provider_credential(
|
model_provider_service.update_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
@ -135,19 +125,17 @@ class ModelProviderCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_provider_credential(
|
model_provider_service.remove_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
@ -159,19 +147,17 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_provider_credential(
|
service.switch_active_provider_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
)
|
)
|
||||||
@ -184,15 +170,13 @@ class ModelProviderValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser = reqparse.RequestParser()
|
"credentials", type=dict, required=True, nullable=False, location="json"
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
tenant_id = current_tenant_id
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
@ -240,17 +224,13 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
tenant_id = current_tenant_id
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"preferred_provider_type",
|
"preferred_provider_type",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
@ -276,14 +256,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
if provider != "anthropic":
|
if provider != "anthropic":
|
||||||
raise ValueError(f"provider name {provider} is invalid")
|
raise ValueError(f"provider name {provider} is invalid")
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
data = BillingService.get_model_provider_payment_link(
|
data = BillingService.get_model_provider_payment_link(
|
||||||
provider_name=provider,
|
provider_name=provider,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
account_id=current_user.id,
|
account_id=current_user.id,
|
||||||
prefilled_email=current_user.email,
|
prefilled_email=current_user.email,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.helper import StrLen, uuid_value
|
from libs.helper import StrLen, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||||
from services.model_provider_service import ModelProviderService
|
from services.model_provider_service import ModelProviderService
|
||||||
|
|
||||||
@ -23,8 +22,9 @@ class DefaultModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
_, tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument(
|
|
||||||
|
parser = reqparse.RequestParser().add_argument(
|
||||||
"model_type",
|
"model_type",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
@ -34,8 +34,6 @@ class DefaultModelApi(Resource):
|
|||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||||
tenant_id=tenant_id, model_type=args["model_type"]
|
tenant_id=tenant_id, model_type=args["model_type"]
|
||||||
@ -47,15 +45,15 @@ class DefaultModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
"model_settings", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_settings = args["model_settings"]
|
model_settings = args["model_settings"]
|
||||||
for model_setting in model_settings:
|
for model_setting in model_settings:
|
||||||
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||||
@ -104,24 +102,26 @@ class ModelProviderModelApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
# To save the model's load balance configs
|
# To save the model's load balance configs
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
.add_argument(
|
||||||
parser.add_argument(
|
"model_type",
|
||||||
"model_type",
|
type=str,
|
||||||
type=str,
|
required=True,
|
||||||
required=True,
|
nullable=False,
|
||||||
nullable=False,
|
choices=[mt.value for mt in ModelType],
|
||||||
choices=[mt.value for mt in ModelType],
|
location="json",
|
||||||
location="json",
|
)
|
||||||
|
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.get("config_from", "") == "custom-model":
|
if args.get("config_from", "") == "custom-model":
|
||||||
@ -129,7 +129,7 @@ class ModelProviderModelApi(Resource):
|
|||||||
raise ValueError("credential_id is required when configuring a custom-model")
|
raise ValueError("credential_id is required when configuring a custom-model")
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.switch_active_custom_model_credential(
|
service.switch_active_custom_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -164,20 +164,22 @@ class ModelProviderModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
.add_argument(
|
||||||
parser.add_argument(
|
"model_type",
|
||||||
"model_type",
|
type=str,
|
||||||
type=str,
|
required=True,
|
||||||
required=True,
|
nullable=False,
|
||||||
nullable=False,
|
choices=[mt.value for mt in ModelType],
|
||||||
choices=[mt.value for mt in ModelType],
|
location="json",
|
||||||
location="json",
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -195,20 +197,22 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="args",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
|
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
||||||
|
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
|
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
@ -257,24 +261,27 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -301,29 +308,33 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, provider: str):
|
def put(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_provider_service.update_model_credential(
|
model_provider_service.update_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -340,24 +351,28 @@ class ModelProviderModelCredentialApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider: str):
|
def delete(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
model_provider_service.remove_model_credential(
|
model_provider_service.remove_model_credential(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -373,24 +388,28 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
service = ModelProviderService()
|
service = ModelProviderService()
|
||||||
service.add_model_credential_to_model_list(
|
service.add_model_credential_to_model_list(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model_type=args["model_type"],
|
model_type=args["model_type"],
|
||||||
model=args["model"],
|
model=args["model"],
|
||||||
@ -407,17 +426,19 @@ class ModelProviderModelEnableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, provider: str):
|
def patch(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -437,17 +458,19 @@ class ModelProviderModelDisableApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, provider: str):
|
def patch(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -465,19 +488,21 @@ class ModelProviderModelValidateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider: str):
|
def post(self, provider: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||||
"model_type",
|
.add_argument(
|
||||||
type=str,
|
"model_type",
|
||||||
required=True,
|
type=str,
|
||||||
nullable=False,
|
required=True,
|
||||||
choices=[mt.value for mt in ModelType],
|
nullable=False,
|
||||||
location="json",
|
choices=[mt.value for mt in ModelType],
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
@ -511,11 +536,11 @@ class ModelProviderModelParameterRuleApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
"model", type=str, required=True, nullable=False, location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||||
@ -531,8 +556,7 @@ class ModelProviderAvailableModelApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, model_type):
|
def get(self, model_type):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
model_provider_service = ModelProviderService()
|
model_provider_service = ModelProviderService()
|
||||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||||
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(debug_required=True)
|
@plugin_permission_required(debug_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
@ -44,10 +43,12 @@ class PluginListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
reqparse.RequestParser()
|
||||||
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
.add_argument("page", type=int, required=False, location="args", default=1)
|
||||||
|
.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
|
||||||
@ -63,8 +64,7 @@ class PluginListLatestVersionsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||||
req.add_argument("plugin_ids", type=list, required=True, location="json")
|
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -81,10 +81,9 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
|
||||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -99,9 +98,11 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||||||
class PluginIconApi(Resource):
|
class PluginIconApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def get(self):
|
||||||
req = reqparse.RequestParser()
|
req = (
|
||||||
req.add_argument("tenant_id", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
req.add_argument("filename", type=str, required=True, location="args")
|
.add_argument("tenant_id", type=str, required=True, location="args")
|
||||||
|
.add_argument("filename", type=str, required=True, location="args")
|
||||||
|
)
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -120,7 +121,7 @@ class PluginUploadFromPkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
file = request.files["pkg"]
|
file = request.files["pkg"]
|
||||||
|
|
||||||
@ -144,12 +145,14 @@ class PluginUploadFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("repo", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("version", type=str, required=True, location="json")
|
.add_argument("repo", type=str, required=True, location="json")
|
||||||
parser.add_argument("package", type=str, required=True, location="json")
|
.add_argument("version", type=str, required=True, location="json")
|
||||||
|
.add_argument("package", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -167,7 +170,7 @@ class PluginUploadFromBundleApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
file = request.files["bundle"]
|
file = request.files["bundle"]
|
||||||
|
|
||||||
@ -191,10 +194,11 @@ class PluginInstallFromPkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if all plugin_unique_identifiers are valid string
|
# check if all plugin_unique_identifiers are valid string
|
||||||
@ -217,13 +221,15 @@ class PluginInstallFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("repo", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("version", type=str, required=True, location="json")
|
.add_argument("repo", type=str, required=True, location="json")
|
||||||
parser.add_argument("package", type=str, required=True, location="json")
|
.add_argument("version", type=str, required=True, location="json")
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
.add_argument("package", type=str, required=True, location="json")
|
||||||
|
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -247,10 +253,11 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
"plugin_unique_identifiers", type=list, required=True, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if all plugin_unique_identifiers are valid string
|
# check if all plugin_unique_identifiers are valid string
|
||||||
@ -273,10 +280,11 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -299,10 +307,11 @@ class PluginFetchManifestApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
"plugin_unique_identifier", type=str, required=True, location="args"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -324,11 +333,13 @@ class PluginFetchInstallTasksApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=int, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
.add_argument("page", type=int, required=True, location="args")
|
||||||
|
.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -346,7 +357,7 @@ class PluginFetchInstallTaskApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def get(self, task_id: str):
|
def get(self, task_id: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||||
@ -361,7 +372,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self, task_id: str):
|
def post(self, task_id: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||||
@ -376,7 +387,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||||
@ -391,7 +402,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self, task_id: str, identifier: str):
|
def post(self, task_id: str, identifier: str):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||||
@ -406,11 +417,13 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -430,14 +443,16 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
parser.add_argument("repo", type=str, required=True, location="json")
|
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
parser.add_argument("version", type=str, required=True, location="json")
|
.add_argument("repo", type=str, required=True, location="json")
|
||||||
parser.add_argument("package", type=str, required=True, location="json")
|
.add_argument("version", type=str, required=True, location="json")
|
||||||
|
.add_argument("package", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -462,11 +477,10 @@ class PluginUninstallApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@plugin_permission_required(install_required=True)
|
@plugin_permission_required(install_required=True)
|
||||||
def post(self):
|
def post(self):
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser().add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||||
@ -480,19 +494,22 @@ class PluginChangePermissionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
user = current_user
|
user = current_user
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = (
|
||||||
req.add_argument("install_permission", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
req.add_argument("debug_permission", type=str, required=True, location="json")
|
.add_argument("install_permission", type=str, required=True, location="json")
|
||||||
|
.add_argument("debug_permission", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||||
|
|
||||||
@ -503,7 +520,7 @@ class PluginFetchPermissionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
permission = PluginPermissionService.get_permission(tenant_id)
|
permission = PluginPermissionService.get_permission(tenant_id)
|
||||||
if not permission:
|
if not permission:
|
||||||
@ -529,18 +546,20 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
# check if the user is admin or owner
|
# check if the user is admin or owner
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = current_user.current_tenant_id
|
|
||||||
user_id = current_user.id
|
user_id = current_user.id
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("provider", type=str, required=True, location="args")
|
.add_argument("plugin_id", type=str, required=True, location="args")
|
||||||
parser.add_argument("action", type=str, required=True, location="args")
|
.add_argument("provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
.add_argument("action", type=str, required=True, location="args")
|
||||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
.add_argument("parameter", type=str, required=True, location="args")
|
||||||
|
.add_argument("provider_type", type=str, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -565,17 +584,17 @@ class PluginChangePreferencesApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = (
|
||||||
req.add_argument("permission", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
.add_argument("permission", type=dict, required=True, location="json")
|
||||||
|
.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
permission = args["permission"]
|
permission = args["permission"]
|
||||||
|
|
||||||
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
||||||
@ -621,7 +640,7 @@ class PluginFetchPreferencesApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
permission = PluginPermissionService.get_permission(tenant_id)
|
permission = PluginPermissionService.get_permission(tenant_id)
|
||||||
permission_dict = {
|
permission_dict = {
|
||||||
@ -661,10 +680,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# exclude one single plugin
|
# exclude one single plugin
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
|
||||||
req.add_argument("plugin_id", type=str, required=True, location="json")
|
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import io
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import make_response, redirect, request, send_file
|
from flask import make_response, redirect, request, send_file
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import (
|
from flask_restx import (
|
||||||
Resource,
|
Resource,
|
||||||
reqparse,
|
reqparse,
|
||||||
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.tools.entities.tool_entities import CredentialType
|
from core.tools.entities.tool_entities import CredentialType
|
||||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import ToolProviderID
|
from models.provider_ids import ToolProviderID
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||||
@ -53,13 +52,11 @@ class ToolProviderListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
req = reqparse.RequestParser()
|
req = reqparse.RequestParser().add_argument(
|
||||||
req.add_argument(
|
|
||||||
"type",
|
"type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||||
@ -78,9 +75,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||||
@ -96,9 +91,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||||
|
|
||||||
@ -109,13 +102,13 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
req = reqparse.RequestParser().add_argument(
|
||||||
req = reqparse.RequestParser()
|
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||||
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
)
|
||||||
args = req.parse_args()
|
args = req.parse_args()
|
||||||
|
|
||||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||||
@ -131,15 +124,16 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
|
||||||
|
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args["type"] not in CredentialType.values():
|
if args["type"] not in CredentialType.values():
|
||||||
@ -161,18 +155,19 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -193,7 +188,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||||
@ -218,23 +213,24 @@ class ToolApiProviderAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[])
|
||||||
|
.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -258,14 +254,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args")
|
||||||
|
|
||||||
parser.add_argument("url", type=str, required=True, nullable=False, location="args")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -282,14 +275,13 @@ class ToolApiProviderListToolsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
"provider", type=str, required=True, nullable=False, location="args"
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -308,24 +300,25 @@ class ToolApiProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
.add_argument("original_provider", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
|
.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -350,17 +343,16 @@ class ToolApiProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
"provider", type=str, required=True, nullable=False, location="json"
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -377,14 +369,13 @@ class ToolApiProviderGetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
"provider", type=str, required=True, nullable=False, location="args"
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -401,8 +392,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider, credential_type):
|
def get(self, provider, credential_type):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||||
@ -417,9 +407,9 @@ class ToolApiProviderSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
"schema", type=str, required=True, nullable=False, location="json"
|
||||||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -434,19 +424,20 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
.add_argument("tool_name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
.add_argument("provider_name", type=str, required=False, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
.add_argument("parameters", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return ApiToolManageService.test_api_tool_preview(
|
return ApiToolManageService.test_api_tool_preview(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
args["provider_name"] or "",
|
args["provider_name"] or "",
|
||||||
args["tool_name"],
|
args["tool_name"],
|
||||||
args["credentials"],
|
args["credentials"],
|
||||||
@ -462,23 +453,24 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = (
|
||||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||||
|
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
@ -502,23 +494,24 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = (
|
||||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("label", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("description", type=str, required=True, nullable=False, location="json")
|
.add_argument("label", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
.add_argument("description", type=str, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
.add_argument("icon", type=dict, required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json")
|
||||||
reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="")
|
||||||
|
.add_argument("labels", type=list[str], required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
@ -545,16 +538,16 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
reqparser = reqparse.RequestParser()
|
reqparser = reqparse.RequestParser().add_argument(
|
||||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
|
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
@ -571,14 +564,15 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
|
.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -606,13 +600,13 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
"workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -631,10 +625,9 @@ class ToolBuiltinListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -653,8 +646,7 @@ class ToolApiListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -672,10 +664,9 @@ class ToolWorkflowListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
[
|
[
|
||||||
@ -709,19 +700,18 @@ class ToolPluginOAuthApi(Resource):
|
|||||||
provider_name = tool_provider.provider_name
|
provider_name = tool_provider.provider_name
|
||||||
|
|
||||||
# todo check permission
|
# todo check permission
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||||
if oauth_client_params is None:
|
if oauth_client_params is None:
|
||||||
raise Forbidden("no oauth available client config found for this tool provider")
|
raise Forbidden("no oauth available client config found for this tool provider")
|
||||||
|
|
||||||
oauth_handler = OAuthHandler()
|
oauth_handler = OAuthHandler()
|
||||||
context_id = OAuthProxyService.create_proxy_context(
|
context_id = OAuthProxyService.create_proxy_context(
|
||||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||||
)
|
)
|
||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||||
authorization_url_response = oauth_handler.get_authorization_url(
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
@ -800,11 +790,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return BuiltinToolManageService.set_default_provider(
|
return BuiltinToolManageService.set_default_provider(
|
||||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -814,18 +804,20 @@ class ToolOAuthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider):
|
def post(self, provider):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not user.is_admin_or_owner:
|
if not user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=args.get("client_params", {}),
|
||||||
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||||
@ -835,20 +827,18 @@ class ToolOAuthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_custom_oauth_client_params(
|
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider):
|
def delete(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.delete_custom_oauth_client_params(
|
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||||
tenant_id=current_user.current_tenant_id, provider=provider
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -858,9 +848,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
tenant_id=current_tenant_id, provider_name=provider
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -871,7 +862,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider):
|
def get(self, provider):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||||
@ -887,25 +878,25 @@ class ToolProviderMCPApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||||
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
|
||||||
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
|
.add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
|
||||||
|
.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||||
)
|
)
|
||||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user, tenant_id = current_account_with_tenant()
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
MCPToolManageService.create_mcp_provider(
|
MCPToolManageService.create_mcp_provider(
|
||||||
tenant_id=user.current_tenant_id,
|
tenant_id=tenant_id,
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
icon=args["icon"],
|
icon=args["icon"],
|
||||||
@ -923,25 +914,28 @@ class ToolProviderMCPApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self):
|
def put(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
.add_argument("timeout", type=float, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("headers", type=dict, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not is_valid_url(args["server_url"]):
|
if not is_valid_url(args["server_url"]):
|
||||||
if "[__HIDDEN__]" in args["server_url"]:
|
if "[__HIDDEN__]" in args["server_url"]:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
MCPToolManageService.update_mcp_provider(
|
MCPToolManageService.update_mcp_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=args["provider_id"],
|
provider_id=args["provider_id"],
|
||||||
server_url=args["server_url"],
|
server_url=args["server_url"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -959,10 +953,12 @@ class ToolProviderMCPApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self):
|
def delete(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
"provider_id", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
@ -972,12 +968,14 @@ class ToolMCPAuthApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
provider_id = args["provider_id"]
|
provider_id = args["provider_id"]
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if not provider:
|
if not provider:
|
||||||
raise ValueError("provider not found")
|
raise ValueError("provider not found")
|
||||||
@ -1018,8 +1016,8 @@ class ToolMCPDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||||
|
|
||||||
|
|
||||||
@ -1029,8 +1027,7 @@ class ToolMCPListAllApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
|
|
||||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||||
|
|
||||||
@ -1043,7 +1040,7 @@ class ToolMCPUpdateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, provider_id):
|
def get(self, provider_id):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
@ -1054,9 +1051,11 @@ class ToolMCPUpdateApi(Resource):
|
|||||||
@console_ns.route("/mcp/oauth/callback")
|
@console_ns.route("/mcp/oauth/callback")
|
||||||
class ToolMCPCallbackApi(Resource):
|
class ToolMCPCallbackApi(Resource):
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
|
.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||||
|
.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
state_key = args["state"]
|
state_key = args["state"]
|
||||||
authorization_code = args["code"]
|
authorization_code = args["code"]
|
||||||
|
|||||||
@ -23,8 +23,8 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account, Tenant, TenantStatus
|
from models.account import Tenant, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
@ -70,8 +70,7 @@ class TenantListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
tenants = TenantService.get_join_tenants(current_user)
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
tenant_dicts = []
|
tenant_dicts = []
|
||||||
|
|
||||||
@ -85,7 +84,7 @@ class TenantListApi(Resource):
|
|||||||
"status": tenant.status,
|
"status": tenant.status,
|
||||||
"created_at": tenant.created_at,
|
"created_at": tenant.created_at,
|
||||||
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
|
||||||
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
|
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||||
}
|
}
|
||||||
|
|
||||||
tenant_dicts.append(tenant_dict)
|
tenant_dicts.append(tenant_dict)
|
||||||
@ -98,9 +97,11 @@ class WorkspaceListApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@admin_required
|
@admin_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
||||||
@ -130,8 +131,7 @@ class TenantApi(Resource):
|
|||||||
if request.path == "/info":
|
if request.path == "/info":
|
||||||
logger.warning("Deprecated URL /info was used.")
|
logger.warning("Deprecated URL /info was used.")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
tenant = current_user.current_tenant
|
tenant = current_user.current_tenant
|
||||||
if not tenant:
|
if not tenant:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
@ -155,10 +155,8 @@ class SwitchWorkspaceApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# check if tenant_id is valid, 403 if not
|
# check if tenant_id is valid, 403 if not
|
||||||
@ -181,16 +179,14 @@ class CustomConfigWorkspaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("remove_webapp_brand", type=bool, location="json")
|
.add_argument("remove_webapp_brand", type=bool, location="json")
|
||||||
parser.add_argument("replace_webapp_logo", type=str, location="json")
|
.add_argument("replace_webapp_logo", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||||
if not current_user.current_tenant_id:
|
|
||||||
raise ValueError("No current tenant")
|
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
|
||||||
|
|
||||||
custom_config_dict = {
|
custom_config_dict = {
|
||||||
"remove_webapp_brand": args["remove_webapp_brand"],
|
"remove_webapp_brand": args["remove_webapp_brand"],
|
||||||
@ -212,8 +208,7 @@ class WebappLogoWorkspaceApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("workspace_custom")
|
@cloud_edition_billing_resource_check("workspace_custom")
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
raise NoFileUploadedError()
|
raise NoFileUploadedError()
|
||||||
@ -253,15 +248,13 @@ class WorkspaceInfoApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
# Change workspace name
|
# Change workspace name
|
||||||
def post(self):
|
def post(self):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("Invalid user account")
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not current_user.current_tenant_id:
|
if not current_tenant_id:
|
||||||
raise ValueError("No current tenant")
|
raise ValueError("No current tenant")
|
||||||
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
|
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||||
tenant.name = args["name"]
|
tenant.name = args["name"]
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|||||||
@ -12,8 +12,8 @@ from configs import dify_config
|
|||||||
from controllers.console.workspace.error import AccountNotInitializedError
|
from controllers.console.workspace.error import AccountNotInitializedError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account, AccountStatus
|
from models.account import AccountStatus
|
||||||
from models.dataset import RateLimitLog
|
from models.dataset import RateLimitLog
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.feature_service import FeatureService, LicenseStatus
|
from services.feature_service import FeatureService, LicenseStatus
|
||||||
@ -25,18 +25,12 @@ P = ParamSpec("P")
|
|||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
def _current_account() -> Account:
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
def account_initialization_required(view: Callable[P, R]):
|
def account_initialization_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
# check account initialization
|
# check account initialization
|
||||||
account = _current_account()
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||||
if account.status == AccountStatus.UNINITIALIZED:
|
|
||||||
raise AccountNotInitializedError()
|
raise AccountNotInitializedError()
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -80,9 +74,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if not features.billing.enabled:
|
if not features.billing.enabled:
|
||||||
abort(403, "Billing feature is not enabled.")
|
abort(403, "Billing feature is not enabled.")
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -94,10 +87,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
features = FeatureService.get_features(tenant_id)
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
members = features.members
|
members = features.members
|
||||||
apps = features.apps
|
apps = features.apps
|
||||||
@ -138,9 +129,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||||||
def interceptor(view: Callable[P, R]):
|
def interceptor(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
if resource == "add_segment":
|
if resource == "add_segment":
|
||||||
if features.billing.subscription.plan == "sandbox":
|
if features.billing.subscription.plan == "sandbox":
|
||||||
@ -163,13 +153,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
if resource == "knowledge":
|
if resource == "knowledge":
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
|
||||||
if knowledge_rate_limit.enabled:
|
if knowledge_rate_limit.enabled:
|
||||||
current_time = int(time.time() * 1000)
|
current_time = int(time.time() * 1000)
|
||||||
key = f"rate_limit_{tenant_id}"
|
key = f"rate_limit_{current_tenant_id}"
|
||||||
|
|
||||||
redis_client.zadd(key, {current_time: current_time})
|
redis_client.zadd(key, {current_time: current_time})
|
||||||
|
|
||||||
@ -180,7 +168,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||||||
if request_count > knowledge_rate_limit.limit:
|
if request_count > knowledge_rate_limit.limit:
|
||||||
# add ratelimit record
|
# add ratelimit record
|
||||||
rate_limit_log = RateLimitLog(
|
rate_limit_log = RateLimitLog(
|
||||||
tenant_id=tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||||
operation="knowledge",
|
operation="knowledge",
|
||||||
)
|
)
|
||||||
@ -200,17 +188,15 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
tenant_id = account.current_tenant_id
|
|
||||||
features = FeatureService.get_features(tenant_id)
|
|
||||||
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
utm_info = request.cookies.get("utm_info")
|
utm_info = request.cookies.get("utm_info")
|
||||||
|
|
||||||
if utm_info:
|
if utm_info:
|
||||||
utm_info_dict: dict = json.loads(utm_info)
|
utm_info_dict: dict = json.loads(utm_info)
|
||||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||||
|
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -260,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def email_register_enabled(view):
|
def email_register_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if features.is_allow_register:
|
if features.is_allow_register:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
@ -289,9 +275,8 @@ def enable_change_email(view: Callable[P, R]):
|
|||||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.is_allow_transfer_workspace:
|
if features.is_allow_transfer_workspace:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
@ -301,14 +286,31 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def knowledge_pipeline_publish_enabled(view):
|
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
account = _current_account()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
assert account.current_tenant_id is not None
|
features = FeatureService.get_features(current_tenant_id)
|
||||||
features = FeatureService.get_features(account.current_tenant_id)
|
|
||||||
if features.knowledge_pipeline.publish_enabled:
|
if features.knowledge_pipeline.publish_enabled:
|
||||||
return view(*args, **kwargs)
|
return view(*args, **kwargs)
|
||||||
abort(403)
|
abort(403)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def edit_permission_required(f: Callable[P, R]):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from libs.login import current_user
|
||||||
|
from models import Account
|
||||||
|
|
||||||
|
user = current_user._get_current_object() # type: ignore
|
||||||
|
if not isinstance(user, Account):
|
||||||
|
raise Forbidden()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
|
raise Forbidden()
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_function
|
||||||
|
|||||||
@ -46,11 +46,13 @@ class FilePreviewApi(Resource):
|
|||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("timestamp", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("nonce", type=str, required=True, location="args")
|
.add_argument("timestamp", type=str, required=True, location="args")
|
||||||
parser.add_argument("sign", type=str, required=True, location="args")
|
.add_argument("nonce", type=str, required=True, location="args")
|
||||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
.add_argument("sign", type=str, required=True, location="args")
|
||||||
|
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -16,12 +16,13 @@ class ToolFileApi(Resource):
|
|||||||
def get(self, file_id, extension):
|
def get(self, file_id, extension):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser.add_argument("timestamp", type=str, required=True, location="args")
|
.add_argument("timestamp", type=str, required=True, location="args")
|
||||||
parser.add_argument("nonce", type=str, required=True, location="args")
|
.add_argument("nonce", type=str, required=True, location="args")
|
||||||
parser.add_argument("sign", type=str, required=True, location="args")
|
.add_argument("sign", type=str, required=True, location="args")
|
||||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not verify_tool_file_signature(
|
if not verify_tool_file_signature(
|
||||||
|
|||||||
@ -18,19 +18,17 @@ from core.tools.tool_file_manager import ToolFileManager
|
|||||||
from fields.file_fields import build_file_model
|
from fields.file_fields import build_file_model
|
||||||
|
|
||||||
# Define parser for both documentation and validation
|
# Define parser for both documentation and validation
|
||||||
upload_parser = reqparse.RequestParser()
|
upload_parser = (
|
||||||
upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
|
reqparse.RequestParser()
|
||||||
upload_parser.add_argument(
|
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
|
||||||
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
|
.add_argument(
|
||||||
|
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
|
||||||
|
)
|
||||||
|
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
|
||||||
|
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
|
||||||
|
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
|
||||||
|
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
|
||||||
)
|
)
|
||||||
upload_parser.add_argument(
|
|
||||||
"nonce", type=str, required=True, location="args", help="Random string for signature verification"
|
|
||||||
)
|
|
||||||
upload_parser.add_argument(
|
|
||||||
"sign", type=str, required=True, location="args", help="HMAC signature for request validation"
|
|
||||||
)
|
|
||||||
upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
|
|
||||||
upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
|
|
||||||
|
|
||||||
|
|
||||||
@files_ns.route("/upload/for-plugin")
|
@files_ns.route("/upload/for-plugin")
|
||||||
|
|||||||
@ -5,11 +5,13 @@ from controllers.inner_api import inner_api_ns
|
|||||||
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
|
||||||
from tasks.mail_inner_task import send_inner_email_task
|
from tasks.mail_inner_task import send_inner_email_task
|
||||||
|
|
||||||
_mail_parser = reqparse.RequestParser()
|
_mail_parser = (
|
||||||
_mail_parser.add_argument("to", type=str, action="append", required=True)
|
reqparse.RequestParser()
|
||||||
_mail_parser.add_argument("subject", type=str, required=True)
|
.add_argument("to", type=str, action="append", required=True)
|
||||||
_mail_parser.add_argument("body", type=str, required=True)
|
.add_argument("subject", type=str, required=True)
|
||||||
_mail_parser.add_argument("substitutions", type=dict, required=False)
|
.add_argument("body", type=str, required=True)
|
||||||
|
.add_argument("substitutions", type=dict, required=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseMail(Resource):
|
class BaseMail(Resource):
|
||||||
@ -17,7 +19,7 @@ class BaseMail(Resource):
|
|||||||
|
|
||||||
def post(self):
|
def post(self):
|
||||||
args = _mail_parser.parse_args()
|
args = _mail_parser.parse_args()
|
||||||
send_inner_email_task.delay(
|
send_inner_email_task.delay( # type: ignore
|
||||||
to=args["to"],
|
to=args["to"],
|
||||||
subject=args["subject"],
|
subject=args["subject"],
|
||||||
body=args["body"],
|
body=args["body"],
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from core.plugin.entities.request import (
|
|||||||
)
|
)
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from libs.helper import length_prefixed_response
|
from libs.helper import length_prefixed_response
|
||||||
from models.account import Account, Tenant
|
from models import Account, Tenant
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -72,9 +72,11 @@ def get_user_tenant(view: Callable[P, R] | None = None):
|
|||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||||
# fetch json body
|
# fetch json body
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("user_id", type=str, required=True, location="json")
|
.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
|
.add_argument("user_id", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
p = parser.parse_args()
|
p = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns
|
|||||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
@ -25,9 +25,11 @@ class EnterpriseWorkspace(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("owner_email", type=str, required=True, location="json")
|
.add_argument("name", type=str, required=True, location="json")
|
||||||
|
.add_argument("owner_email", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
|
account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
|
||||||
@ -68,8 +70,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
|
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
|
||||||
|
|||||||
@ -33,14 +33,12 @@ def int_or_str(value):
|
|||||||
|
|
||||||
|
|
||||||
# Define parser for both documentation and validation
|
# Define parser for both documentation and validation
|
||||||
mcp_request_parser = reqparse.RequestParser()
|
mcp_request_parser = (
|
||||||
mcp_request_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')"
|
.add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
|
||||||
)
|
.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
|
||||||
mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
|
.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
|
||||||
mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
|
.add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
|
||||||
mcp_request_parser.add_argument(
|
|
||||||
"id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,24 +10,24 @@ from controllers.service_api.wraps import validate_app_token
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
# Define parsers for annotation API
|
# Define parsers for annotation API
|
||||||
annotation_create_parser = reqparse.RequestParser()
|
annotation_create_parser = (
|
||||||
annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question")
|
reqparse.RequestParser()
|
||||||
annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
|
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
|
||||||
|
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
|
||||||
|
)
|
||||||
|
|
||||||
annotation_reply_action_parser = reqparse.RequestParser()
|
annotation_reply_action_parser = (
|
||||||
annotation_reply_action_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
|
.add_argument(
|
||||||
)
|
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
|
||||||
annotation_reply_action_parser.add_argument(
|
)
|
||||||
"embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name"
|
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
|
||||||
)
|
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
|
||||||
annotation_reply_action_parser.add_argument(
|
|
||||||
"embedding_model_name", required=True, type=str, location="json", help="Embedding model name"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -85,11 +85,13 @@ class AudioApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
# Define parser for text-to-audio API
|
# Define parser for text-to-audio API
|
||||||
text_to_audio_parser = reqparse.RequestParser()
|
text_to_audio_parser = (
|
||||||
text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
|
reqparse.RequestParser()
|
||||||
text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
|
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
|
||||||
text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio")
|
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
|
||||||
text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
.add_argument("text", type=str, location="json", help="Text to convert to audio")
|
||||||
|
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/text-to-audio")
|
@service_api_ns.route("/text-to-audio")
|
||||||
|
|||||||
@ -37,40 +37,34 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# Define parser for completion API
|
# Define parser for completion API
|
||||||
completion_parser = reqparse.RequestParser()
|
completion_parser = (
|
||||||
completion_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"inputs", type=dict, required=True, location="json", help="Input parameters for completion"
|
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
|
||||||
)
|
.add_argument("query", type=str, location="json", default="", help="The query string")
|
||||||
completion_parser.add_argument("query", type=str, location="json", default="", help="The query string")
|
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||||
completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||||
completion_parser.add_argument(
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||||
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
|
|
||||||
)
|
|
||||||
completion_parser.add_argument(
|
|
||||||
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define parser for chat API
|
# Define parser for chat API
|
||||||
chat_parser = reqparse.RequestParser()
|
chat_parser = (
|
||||||
chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
|
reqparse.RequestParser()
|
||||||
chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query")
|
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
|
||||||
chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
.add_argument("query", type=str, required=True, location="json", help="The chat query")
|
||||||
chat_parser.add_argument(
|
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||||
"response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode"
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||||
|
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||||
|
.add_argument(
|
||||||
|
"auto_generate_name",
|
||||||
|
type=bool,
|
||||||
|
required=False,
|
||||||
|
default=True,
|
||||||
|
location="json",
|
||||||
|
help="Auto generate conversation name",
|
||||||
|
)
|
||||||
|
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
|
||||||
)
|
)
|
||||||
chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
|
|
||||||
chat_parser.add_argument(
|
|
||||||
"retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source"
|
|
||||||
)
|
|
||||||
chat_parser.add_argument(
|
|
||||||
"auto_generate_name",
|
|
||||||
type=bool,
|
|
||||||
required=False,
|
|
||||||
default=True,
|
|
||||||
location="json",
|
|
||||||
help="Auto generate conversation name",
|
|
||||||
)
|
|
||||||
chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
|
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route("/completion-messages")
|
@service_api_ns.route("/completion-messages")
|
||||||
|
|||||||
@ -24,48 +24,63 @@ from models.model import App, AppMode, EndUser
|
|||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
|
|
||||||
# Define parsers for conversation APIs
|
# Define parsers for conversation APIs
|
||||||
conversation_list_parser = reqparse.RequestParser()
|
conversation_list_parser = (
|
||||||
conversation_list_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"last_id", type=uuid_value, location="args", help="Last conversation ID for pagination"
|
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
|
||||||
)
|
.add_argument(
|
||||||
conversation_list_parser.add_argument(
|
"limit",
|
||||||
"limit",
|
type=int_range(1, 100),
|
||||||
type=int_range(1, 100),
|
required=False,
|
||||||
required=False,
|
default=20,
|
||||||
default=20,
|
location="args",
|
||||||
location="args",
|
help="Number of conversations to return",
|
||||||
help="Number of conversations to return",
|
)
|
||||||
)
|
.add_argument(
|
||||||
conversation_list_parser.add_argument(
|
"sort_by",
|
||||||
"sort_by",
|
type=str,
|
||||||
type=str,
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
required=False,
|
||||||
required=False,
|
default="-updated_at",
|
||||||
default="-updated_at",
|
location="args",
|
||||||
location="args",
|
help="Sort order for conversations",
|
||||||
help="Sort order for conversations",
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_rename_parser = reqparse.RequestParser()
|
conversation_rename_parser = (
|
||||||
conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name")
|
reqparse.RequestParser()
|
||||||
conversation_rename_parser.add_argument(
|
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
|
||||||
"auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name"
|
.add_argument(
|
||||||
|
"auto_generate",
|
||||||
|
type=bool,
|
||||||
|
required=False,
|
||||||
|
default=False,
|
||||||
|
location="json",
|
||||||
|
help="Auto-generate conversation name",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_variables_parser = reqparse.RequestParser()
|
conversation_variables_parser = (
|
||||||
conversation_variables_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"last_id", type=uuid_value, location="args", help="Last variable ID for pagination"
|
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
|
||||||
)
|
.add_argument(
|
||||||
conversation_variables_parser.add_argument(
|
"limit",
|
||||||
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return"
|
type=int_range(1, 100),
|
||||||
|
required=False,
|
||||||
|
default=20,
|
||||||
|
location="args",
|
||||||
|
help="Number of variables to return",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_variable_update_parser = reqparse.RequestParser()
|
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
|
||||||
# using lambda is for passing the already-typed value without modification
|
# using lambda is for passing the already-typed value without modification
|
||||||
# if no lambda, it will be converted to string
|
# if no lambda, it will be converted to string
|
||||||
# the string cannot be converted using json.loads
|
# the string cannot be converted using json.loads
|
||||||
conversation_variable_update_parser.add_argument(
|
"value",
|
||||||
"value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable"
|
required=True,
|
||||||
|
location="json",
|
||||||
|
type=lambda x: x,
|
||||||
|
help="New value for the conversation variable",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,8 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# Define parser for file preview API
|
# Define parser for file preview API
|
||||||
file_preview_parser = reqparse.RequestParser()
|
file_preview_parser = reqparse.RequestParser().add_argument(
|
||||||
file_preview_parser.add_argument(
|
|
||||||
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
|
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -26,25 +26,37 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# Define parsers for message APIs
|
# Define parsers for message APIs
|
||||||
message_list_parser = reqparse.RequestParser()
|
message_list_parser = (
|
||||||
message_list_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID"
|
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
|
||||||
)
|
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
|
||||||
message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
|
.add_argument(
|
||||||
message_list_parser.add_argument(
|
"limit",
|
||||||
"limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return"
|
type=int_range(1, 100),
|
||||||
|
required=False,
|
||||||
|
default=20,
|
||||||
|
location="args",
|
||||||
|
help="Number of messages to return",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
message_feedback_parser = reqparse.RequestParser()
|
message_feedback_parser = (
|
||||||
message_feedback_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating"
|
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
|
||||||
|
.add_argument("content", type=str, location="json", help="Feedback content")
|
||||||
)
|
)
|
||||||
message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content")
|
|
||||||
|
|
||||||
feedback_list_parser = reqparse.RequestParser()
|
feedback_list_parser = (
|
||||||
feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number")
|
reqparse.RequestParser()
|
||||||
feedback_list_parser.add_argument(
|
.add_argument("page", type=int, default=1, location="args", help="Page number")
|
||||||
"limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page"
|
.add_argument(
|
||||||
|
"limit",
|
||||||
|
type=int_range(1, 101),
|
||||||
|
required=False,
|
||||||
|
default=20,
|
||||||
|
location="args",
|
||||||
|
help="Number of feedbacks per page",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,32 +42,36 @@ from services.workflow_app_service import WorkflowAppService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Define parsers for workflow APIs
|
# Define parsers for workflow APIs
|
||||||
workflow_run_parser = reqparse.RequestParser()
|
workflow_run_parser = (
|
||||||
workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
workflow_run_parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
)
|
||||||
|
|
||||||
workflow_log_parser = reqparse.RequestParser()
|
workflow_log_parser = (
|
||||||
workflow_log_parser.add_argument("keyword", type=str, location="args")
|
reqparse.RequestParser()
|
||||||
workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
.add_argument("keyword", type=str, location="args")
|
||||||
workflow_log_parser.add_argument("created_at__before", type=str, location="args")
|
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||||
workflow_log_parser.add_argument("created_at__after", type=str, location="args")
|
.add_argument("created_at__before", type=str, location="args")
|
||||||
workflow_log_parser.add_argument(
|
.add_argument("created_at__after", type=str, location="args")
|
||||||
"created_by_end_user_session_id",
|
.add_argument(
|
||||||
type=str,
|
"created_by_end_user_session_id",
|
||||||
location="args",
|
type=str,
|
||||||
required=False,
|
location="args",
|
||||||
default=None,
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_by_account",
|
||||||
|
type=str,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
)
|
)
|
||||||
workflow_log_parser.add_argument(
|
|
||||||
"created_by_account",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
|
|
||||||
workflow_run_fields = {
|
workflow_run_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user