Merge branch 'main' into fix/secret_variable

This commit is contained in:
PankajKaushal 2025-10-16 13:26:07 +05:30 committed by GitHub
commit 7ba63d0c46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
363 changed files with 5985 additions and 3913 deletions

View File

@ -39,25 +39,11 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@ -93,3 +79,19 @@ jobs:
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -1,6 +1,7 @@
#!/bin/bash
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml

View File

@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
default="plugin-api-key",
)
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
PLUGIN_REMOTE_INSTALL_HOST: str = Field(

View File

@ -7,13 +7,12 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api, console_ns
from .wraps import account_initialization_required, setup_required
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
@ -57,9 +56,9 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
@ -68,15 +67,12 @@ class BaseApiKeyListResource(Resource):
return {"items": keys}
@marshal_with(api_key_fields)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.has_edit_permission:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
@ -93,7 +89,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@ -112,9 +108,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
@ -158,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@ -179,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "app"
resource_model = App
resource_id_field = "app_id"
@ -208,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
@ -229,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
return resp
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"

View File

@ -1,15 +1,14 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import api, console_ns
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_redis import redis_client
@ -42,10 +41,8 @@ class AnnotationReplyActionApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
@ -69,10 +66,8 @@ class AppAnnotationSettingDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
return result, 200
@ -98,10 +93,8 @@ class AppAnnotationSettingUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
@ -124,10 +117,8 @@ class AnnotationReplyActionStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@ -159,10 +150,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
@ -198,10 +187,8 @@ class AnnotationApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
@ -213,10 +200,8 @@ class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# Use request.args.getlist to get annotation_ids array directly
@ -249,10 +234,8 @@ class AnnotationExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
@ -271,11 +254,9 @@ class AnnotationUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
@ -288,10 +269,8 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
@ -310,10 +289,8 @@ class AnnotationBatchImportApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
# check file
if "file" not in request.files:
@ -341,10 +318,8 @@ class AnnotationBatchImportStatusApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id):
if not current_user.is_editor:
raise Forbidden()
job_id = str(job_id)
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key)
@ -376,10 +351,8 @@ class AnnotationHitHistoryListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id):
if not current_user.is_editor:
raise Forbidden()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)

View File

@ -1,7 +1,5 @@
import uuid
from typing import cast
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -12,15 +10,16 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
enterprise_license_required,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import Account, App
from models import App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -56,6 +55,7 @@ class AppListApi(Resource):
@enterprise_license_required
def get(self):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
def uuid_list(value):
try:
@ -90,7 +90,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -129,8 +129,10 @@ class AppListApi(Resource):
@account_initialization_required
@marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
@ -140,19 +142,11 @@ class AppListApi(Resource):
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if current_user.current_tenant_id is None:
raise ValueError("current_user.current_tenant_id cannot be None")
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201
@ -205,13 +199,10 @@ class AppApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def put(self, app_model):
"""Update app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=validate_description_length, location="json")
@ -248,12 +239,9 @@ class AppApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_model):
"""Delete app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
app_service.delete_app(app_model)
@ -283,12 +271,12 @@ class AppCopyApi(Resource):
@login_required
@account_initialization_required
@get_app_model
@edit_permission_required
@marshal_with(app_detail_fields_with_site)
def post(self, app_model):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
@ -301,9 +289,8 @@ class AppCopyApi(Resource):
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
@ -340,12 +327,9 @@ class AppExportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
"""Export app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
@ -371,11 +355,8 @@ class AppNameApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
@ -408,11 +389,8 @@ class AppIconApi(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -441,11 +419,8 @@ class AppSiteStatus(Resource):
@account_initialization_required
@get_app_model
@marshal_with(app_detail_fields)
@edit_permission_required
def post(self, app_model):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
@ -475,6 +450,7 @@ class AppApiStatus(Resource):
@marshal_with(app_detail_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -520,10 +496,9 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
# add app trace
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json")
parser.add_argument("tracing_provider", type=str, required=True, location="json")

View File

@ -1,20 +1,16 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
@ -30,11 +26,10 @@ class AppImportApi(Resource):
@account_initialization_required
@marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
@ -51,7 +46,7 @@ class AppImportApi(Resource):
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
@ -83,16 +78,16 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
@edit_permission_required
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
@ -109,10 +104,8 @@ class AppImportCheckDependenciesApi(Resource):
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
@edit_permission_required
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api, console_ns
@ -15,7 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -151,13 +151,8 @@ class ChatMessageApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")

View File

@ -1,17 +1,16 @@
from datetime import datetime
import pytz # pip install pytz
import pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -22,8 +21,8 @@ from fields.conversation_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Account, Conversation, EndUser, Message, MessageAnnotation
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@ -57,9 +56,9 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -84,6 +83,7 @@ class CompletionConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -137,9 +137,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@ -154,14 +153,12 @@ class CompletionConversationDetailApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -206,9 +203,9 @@ class ChatConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -260,6 +257,7 @@ class ChatConversationApi(Resource):
)
account = current_user
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -341,9 +339,8 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_fields)
@edit_permission_required
def get(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id)
@ -358,14 +355,12 @@ class ChatConversationDetailApi(Resource):
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()
current_user, _ = current_account_with_tenant()
conversation_id = str(conversation_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -374,6 +369,7 @@ class ChatConversationDetailApi(Resource):
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
@ -48,11 +47,11 @@ class RuleGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource):
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource):
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
code_template = (
Python3CodeProvider.get_default_code()
if args["language"] == "python"
@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource):
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_user.current_tenant_id,
current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource):
return {"error": f"invalid node type: {node_type}"}
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource):
)
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],

View File

@ -1,16 +1,15 @@
import json
from enum import StrEnum
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
@api.doc(description="Get MCP server configuration for an application")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@setup_required
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
@ -48,14 +47,14 @@ class AppMCPServerController(Resource):
)
@api.response(201, "MCP server configuration created successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_fields)
@edit_permission_required
def post(self, app_model):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
@ -71,7 +70,7 @@ class AppMCPServerController(Resource):
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
@ -95,14 +94,13 @@ class AppMCPServerController(Resource):
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
@api.response(403, "Insufficient permissions")
@api.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def put(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=False, location="json")
@ -142,13 +140,13 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
@edit_permission_required
def get(self, server_id):
if not current_user.is_editor:
raise NotFound()
_, current_tenant_id = current_account_with_tenant()
server = (
db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.tenant_id == current_tenant_id)
.first()
)
if not server:

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import api, console_ns
from controllers.console.app.error import (
@ -17,6 +17,7 @@ from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDi
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
@ -26,8 +27,7 @@ from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
@ -56,15 +56,13 @@ class ChatMessageListApi(Resource):
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@setup_required
@login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required
def get(self, app_model):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
parser.add_argument("first_id", type=uuid_value, location="args")
@ -154,8 +152,7 @@ class MessageFeedbackApi(Resource):
@login_required
@account_initialization_required
def post(self, app_model):
if current_user is None:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
@ -211,18 +208,14 @@ class MessageAnnotationApi(Resource):
)
@api.response(200, "Annotation created successfully", annotation_fields)
@api.response(403, "Insufficient permissions")
@marshal_with(annotation_fields)
@get_app_model
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@get_app_model
@marshal_with(annotation_fields)
@account_initialization_required
@edit_permission_required
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
parser.add_argument("question", required=True, type=str, location="json")
@ -270,6 +263,7 @@ class MessageSuggestedQuestionApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id = str(message_id)
try:
@ -304,12 +298,12 @@ class MessageApi(Resource):
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@api.response(200, "Message retrieved successfully", message_detail_fields)
@api.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(message_detail_fields)
def get(self, app_model, message_id):
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

View File

@ -2,7 +2,6 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
)
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@ -9,8 +8,8 @@ from controllers.console.wraps import account_initialization_required, setup_req
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Account, Site
from libs.login import current_account_with_tenant, login_required
from models import Site
def parse_app_site_args():
@ -76,9 +75,10 @@ class AppSite(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
@ -107,8 +107,6 @@ class AppSite(Resource):
if value is not None:
setattr(site, attr_name, value)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()
@ -131,6 +129,8 @@ class AppSiteAccessTokenReset(Resource):
@marshal_with(app_site_fields)
def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -140,8 +140,6 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound
site.code = Site.generate_code(16)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
site.updated_by = current_user.id
site.updated_at = naive_utc_now()
db.session.commit()

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from controllers.console import api, console_ns
@ -13,7 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message
@ -37,7 +36,7 @@ class DailyMessageStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -53,6 +52,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -109,13 +109,13 @@ class DailyConversationStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -175,7 +175,7 @@ class DailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -191,7 +191,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -247,7 +247,7 @@ class DailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -264,7 +264,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -322,7 +322,7 @@ class AverageSessionInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -346,7 +346,7 @@ FROM
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -413,7 +413,7 @@ class UserSatisfactionRateStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -433,7 +433,7 @@ WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -494,7 +494,7 @@ class AverageResponseTimeStatistic(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -510,7 +510,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -566,7 +566,7 @@ class TokensPerSecondStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -585,7 +585,7 @@ WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@ -12,7 +12,7 @@ import services
from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -27,9 +27,8 @@ from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
@ -110,14 +105,12 @@ class DraftWorkflowApi(Resource):
@api.response(200, "Draft workflow synced successfully", workflow_fields)
@api.response(400, "Invalid workflow configuration")
@api.response(403, "Permission denied")
@edit_permission_required
def post(self, app_model: App):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
@ -149,10 +142,6 @@ class DraftWorkflowApi(Resource):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
try:
@ -206,17 +195,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
@ -271,16 +255,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -323,16 +303,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -375,17 +351,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -428,17 +399,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -480,17 +446,12 @@ class DraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Run draft workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
@ -526,17 +487,11 @@ class WorkflowTaskStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, task_id: str):
"""
Stop workflow task
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
@ -568,17 +523,12 @@ class DraftWorkflowNodeRunApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Run draft workflow node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
@ -622,17 +572,11 @@ class PublishedWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model)
@ -644,16 +588,12 @@ class PublishedWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Publish workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
@ -702,17 +642,11 @@ class DefaultBlockConfigsApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
workflow_service = WorkflowService()
return workflow_service.get_default_block_configs()
@ -729,16 +663,11 @@ class DefaultBlockConfigApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def get(self, app_model: App, block_type: str):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
@ -769,17 +698,14 @@ class ConvertToWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@edit_permission_required
def post(self, app_model: App):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
if request.data:
parser = reqparse.RequestParser()
@ -812,15 +738,12 @@ class PublishedAllWorkflowApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_fields)
@edit_permission_required
def get(self, app_model: App):
"""
Get published workflows
"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
@ -879,16 +802,12 @@ class WorkflowByIdApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
@ -934,16 +853,11 @@ class WorkflowByIdApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def delete(self, app_model: App, workflow_id: str):
"""
Delete workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction

View File

@ -22,8 +22,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
from models import Account, App, AppMode
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService

View File

@ -1,6 +1,5 @@
from typing import cast
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@ -14,7 +13,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser
from services.workflow_run_service import WorkflowRunService

View File

@ -4,7 +4,6 @@ from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api, console_ns
@ -12,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
@ -29,7 +28,7 @@ class WorkflowDailyRunsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -49,7 +48,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -97,7 +96,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -117,7 +116,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -165,7 +164,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -185,7 +184,7 @@ WHERE
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -238,7 +237,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model):
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
@ -271,7 +270,7 @@ GROUP BY
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
assert account.timezone is not None
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from libs.login import current_account_with_tenant
from models import App, AppMode
from models.account import Account
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
assert isinstance(current_user, Account)
_, current_tenant_id = current_account_with_tenant()
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -7,7 +7,7 @@ from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()

View File

@ -1,10 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
@login_required
@account_initialization_required
def get(self):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204

View File

@ -2,13 +2,12 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api, console_ns
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
@api.response(403, "Admin privileges required")
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()

View File

@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError

View File

@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService

View File

@ -1,5 +1,3 @@
from typing import cast
import flask_login
from flask import request
from flask_restx import Resource, reqparse
@ -26,7 +24,7 @@ from controllers.console.error import (
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from models.account import Account
from libs.login import current_account_with_tenant
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
@ -96,7 +94,8 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)

View File

@ -14,8 +14,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError

View File

@ -1,16 +1,15 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
from typing import Concatenate, ParamSpec, TypeVar
import flask_login
from flask import jsonify, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -116,7 +115,8 @@ class OAuthServerUserAuthorizeApi(Resource):
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)

View File

@ -2,8 +2,7 @@ from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required
from models.model import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
@ -14,17 +13,13 @@ class Subscription(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@ -34,7 +29,6 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
return BillingService.get_invoices(current_user.email, current_tenant_id)

View File

@ -2,8 +2,7 @@ from flask import request
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from .. import console_ns
@ -17,19 +16,17 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")
return BillingService.get_compliance_download_link(
doc_name=args.doc_name,
account_id=current_user.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
ip=ip_address,
device_info=device_info,
)

View File

@ -3,7 +3,6 @@ from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.tenant_id == current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
).all()
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
)
text_docs = extractor.extract()
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],

View File

@ -1,7 +1,6 @@
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -138,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
@ -146,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -251,6 +250,7 @@ class DatasetListApi(Resource):
required=False,
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -258,11 +258,11 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=cast(Account, current_user),
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@ -286,6 +286,7 @@ class DatasetApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -305,7 +306,7 @@ class DatasetApi(Resource):
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@ -418,6 +419,7 @@ class DatasetApi(Resource):
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
@ -440,7 +442,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
@ -464,9 +466,10 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
try:
@ -505,6 +508,7 @@ class DatasetQueryApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -556,15 +560,14 @@ class DatasetIndexingEstimateApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
if file_details is None:
@ -592,7 +595,7 @@ class DatasetIndexingEstimateApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@ -608,7 +611,7 @@ class DatasetIndexingEstimateApi(Resource):
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
@ -621,7 +624,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
@ -652,6 +655,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required
@marshal_with(related_app_list)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -683,11 +687,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@ -739,10 +742,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
@ -752,12 +754,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
@ -770,7 +773,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@ -790,6 +793,7 @@ class DatasetApiDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
@ -799,7 +803,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
@ -898,6 +902,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -6,7 +6,6 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -53,9 +52,8 @@ from fields.document_fields import (
document_with_segments_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@ -65,6 +63,7 @@ logger = logging.getLogger(__name__)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -79,12 +78,13 @@ class DocumentResource(Resource):
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -112,6 +112,7 @@ class GetProcessRuleApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, _ = current_account_with_tenant()
req_data = request.args
document_id = req_data.get("document_id")
@ -168,6 +169,7 @@ class DatasetDocumentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
@ -199,7 +201,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if search:
search = f"%{search}%"
@ -273,6 +275,7 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -372,6 +375,7 @@ class DatasetInitApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
@ -402,7 +406,7 @@ class DatasetInitApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
@ -419,9 +423,9 @@ class DatasetInitApi(Resource):
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
knowledge_config=knowledge_config,
account=cast(Account, current_user),
account=current_user,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -447,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@ -482,7 +487,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
try:
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
[extract_setting],
data_process_rule_dict,
document.doc_form,
@ -511,6 +516,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@login_required
@account_initialization_required
def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
@ -530,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
@ -553,7 +559,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
@ -569,7 +575,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
@ -583,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
data_process_rule_dict,
document.doc_form,
@ -834,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@ -884,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
@login_required
@account_initialization_required
def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
@ -931,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
@ -1077,12 +1086,13 @@ class DocumentRenameApi(DocumentResource):
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -1102,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
@account_initialization_required
def get(self, dataset_id, document_id):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@ -1110,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
if document.tenant_id != current_tenant_id:
raise Forbidden("No permission.")
if document.data_source_type != "website_crawl":
raise ValueError("Document is not a website document.")

View File

@ -1,7 +1,6 @@
import uuid
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
select(DocumentSegment)
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
upload_file_id,
dataset_id,
document_id,
current_user.current_tenant_id,
current_tenant_id,
current_user.id,
)
except Exception as e:
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.first()
)
if not segment:
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
db.session.query(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id,
)

View File

@ -1,7 +1,4 @@
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -10,8 +7,7 @@ from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@ -40,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_user.current_tenant_id, search
page, limit, current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
@ -60,6 +57,7 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
@ -85,7 +83,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
tenant_id=current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -115,6 +113,7 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
parser = reqparse.RequestParser()
@ -136,7 +135,7 @@ class ExternalApiTemplateApi(Resource):
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
@ -148,13 +147,14 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
@ -199,7 +199,8 @@ class ExternalDatasetCreateApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -223,7 +224,7 @@ class ExternalDatasetCreateApi(Resource):
try:
dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
user_id=current_user.id,
args=args,
)
@ -255,6 +256,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@login_required
@account_initialization_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -277,7 +279,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)

View File

@ -1,13 +1,12 @@
from typing import Literal
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@ -24,6 +23,7 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
@ -59,6 +59,7 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -79,6 +80,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -108,6 +110,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -128,6 +131,7 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -1,19 +1,15 @@
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
current_user, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
user_id=current_user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
@ -130,9 +126,9 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
@ -145,7 +141,7 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
@ -160,8 +156,10 @@ class DatasourceAuth(Resource):
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
_, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
@ -173,18 +171,20 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
@ -197,18 +197,20 @@ class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
@ -224,10 +226,10 @@ class DatasourceAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -237,10 +239,10 @@ class DatasourceHardCodeAuthListApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -249,9 +251,10 @@ class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
@ -259,7 +262,7 @@ class DatasourceAuthOauthCustomClient(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
@ -270,10 +273,12 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
@ -284,16 +289,17 @@ class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
@ -305,9 +311,10 @@ class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@ -315,7 +322,7 @@ class DatasourceUpdateProviderNameApi(Resource):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, marshal, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -13,7 +12,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="",
description="",

View File

@ -23,7 +23,7 @@ from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService

View File

@ -1,6 +1,3 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -13,8 +10,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -28,7 +24,8 @@ class RagPipelineImportApi(Resource):
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -47,7 +44,7 @@ class RagPipelineImportApi(Resource):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
account = current_user
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
@ -74,15 +71,16 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
@ -100,7 +98,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session:
@ -117,7 +116,8 @@ class RagPipelineExportApi(Resource):
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params

View File

@ -18,6 +18,7 @@ from controllers.console.app.error import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -36,8 +37,8 @@ from fields.workflow_run_fields import (
)
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
from services.errors.app import WorkflowHashNotEqualError
@ -56,15 +57,12 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -79,13 +77,13 @@ class DraftRagPipelineApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
@ -154,13 +152,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
@ -194,7 +192,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -229,7 +228,8 @@ class DraftRagPipelineRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -264,7 +264,8 @@ class PublishedRagPipelineRunApi(Resource):
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -303,7 +304,7 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
@ -344,7 +345,7 @@ class PublishedRagPipelineRunApi(Resource):
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.is_editor:
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
@ -385,7 +386,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -428,7 +430,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -472,7 +475,8 @@ class RagPipelineDraftNodeRunApi(Resource):
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -505,7 +509,8 @@ class RagPipelineTaskStopApi(Resource):
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -525,7 +530,8 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published:
return None
@ -545,7 +551,8 @@ class PublishedRagPipelineApi(Resource):
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService()
@ -580,7 +587,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
@ -599,7 +607,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
Get default block config
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -631,7 +640,8 @@ class PublishedAllRagPipelineApi(Resource):
"""
Get published workflows
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -681,7 +691,8 @@ class RagPipelineByIdApi(Resource):
Update workflow attributes
"""
# Check permission
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -733,13 +744,11 @@ class PublishedRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@ -759,13 +768,11 @@ class PublishedRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@ -785,13 +792,11 @@ class DraftRagPipelineFirstStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get first step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@ -811,13 +816,11 @@ class DraftRagPipelineSecondStepApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
def get(self, pipeline: Pipeline):
"""
Get second step parameters of rag pipeline
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, location="args")
args = parser.parse_args()
@ -880,7 +883,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id):
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
"""
@ -903,14 +906,8 @@ class DatasourceListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
if not isinstance(user, Account):
raise Forbidden()
tenant_id = user.current_tenant_id
if not tenant_id:
raise Forbidden()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
@ -940,11 +937,11 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
if not isinstance(current_user, Account):
raise Forbidden()
@edit_permission_required
def post(self, dataset_id: str):
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
if not current_user.is_dataset_operator:
raise Forbidden()
dataset_id = str(dataset_id)
@ -959,14 +956,13 @@ class RagPipelineDatasourceVariableApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
"""
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=dict, required=True, location="json")

View File

@ -3,8 +3,7 @@ from functools import wraps
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
@ -17,8 +16,7 @@ def get_rag_pipeline(
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")
_, current_tenant_id = current_account_with_tenant()
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
@ -27,7 +25,7 @@ def get_rag_pipeline(
pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)

View File

@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_user, login_required
from models import Account, App, InstalledApp, RecommendedApp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
installed_apps = db.session.scalars(
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("App not found")
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
current_tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
"""
def delete(self, installed_app):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
db.session.delete(installed_app)

View File

@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -48,6 +47,7 @@ logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@ -61,8 +61,6 @@ class MessageListApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
@ -78,6 +76,7 @@ class MessageListApi(InstalledAppResource):
)
class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
@ -88,8 +87,6 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
@ -109,6 +106,7 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -124,8 +122,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
@ -159,6 +155,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -167,8 +164,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)

View File

@ -7,8 +7,7 @@ from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.login import current_user
from models import Account
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -35,6 +34,7 @@ class SavedMessageListApi(InstalledAppResource):
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -44,11 +44,10 @@ class SavedMessageListApi(InstalledAppResource):
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -58,8 +57,6 @@ class SavedMessageListApi(InstalledAppResource):
args = parser.parse_args()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.save(app_model, current_user, args["message_id"])
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -72,6 +69,7 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
@ -79,8 +77,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204

View File

@ -8,9 +8,8 @@ from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -24,13 +23,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
installed_app = (
db.session.query(InstalledApp)
.where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
.first()
)
@ -56,9 +52,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(

View File

@ -4,8 +4,7 @@ from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@ -47,9 +46,7 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@ -70,16 +67,16 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
parser.add_argument("api_key", type=str, required=True, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
@ -99,10 +96,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@ -125,12 +120,10 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
@ -154,12 +147,10 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required
@account_initialization_required
def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)

View File

@ -1,7 +1,6 @@
from flask_restx import Resource, fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService
from . import api, console_ns
@ -23,9 +22,9 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -1,7 +1,6 @@
from typing import Literal
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with
from werkzeug.exceptions import Forbidden
@ -22,8 +21,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
@ -53,6 +51,7 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -65,16 +64,12 @@ class FileApi(Resource):
if not file.filename:
raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden()
if source not in ("datasets", None):
source = None
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
@ -108,4 +103,4 @@ class FileSupportTypeApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
assert isinstance(current_user, Account)
user = current_user
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
@ -24,11 +23,10 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_, current_tenant_id = current_account_with_tenant()
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
return tags, 200
@ -36,8 +34,7 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -63,8 +60,7 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -87,8 +83,7 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
@ -105,8 +100,7 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -133,8 +127,7 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()

View File

@ -2,11 +2,11 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
@ -20,8 +20,9 @@ def plugin_permission_required(
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
with Session(db.engine) as session:
permission = (

View File

@ -2,7 +2,6 @@ from datetime import datetime
import pytz
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -37,9 +36,8 @@ from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -50,9 +48,7 @@ class AccountInitApi(Resource):
@setup_required
@login_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
if account.status == "active":
raise AccountAlreadyInitedError()
@ -106,8 +102,7 @@ class AccountProfileApi(Resource):
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
return current_user
@ -118,8 +113,7 @@ class AccountNameApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
@ -140,8 +134,7 @@ class AccountAvatarApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="json")
args = parser.parse_args()
@ -158,8 +151,7 @@ class AccountInterfaceLanguageApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("interface_language", type=supported_language, required=True, location="json")
args = parser.parse_args()
@ -176,8 +168,7 @@ class AccountInterfaceThemeApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
args = parser.parse_args()
@ -194,8 +185,7 @@ class AccountTimezoneApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("timezone", type=str, required=True, location="json")
args = parser.parse_args()
@ -216,8 +206,7 @@ class AccountPasswordApi(Resource):
@account_initialization_required
@marshal_with(account_fields)
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("password", type=str, required=False, location="json")
parser.add_argument("new_password", type=str, required=True, location="json")
@ -253,9 +242,7 @@ class AccountIntegrateApi(Resource):
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
@ -298,9 +285,7 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@ -314,9 +299,7 @@ class AccountDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
@ -358,9 +341,7 @@ class EducationVerifyApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
@ -380,9 +361,7 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
@ -399,9 +378,7 @@ class EducationApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
account = current_user
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
# convert expire_at to UTC timestamp from isoformat
@ -441,6 +418,7 @@ class ChangeEmailSendEmailApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
@ -467,8 +445,6 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if user_email != current_user.email:
raise InvalidEmailError()
else:
@ -551,8 +527,7 @@ class ChangeEmailResetApi(Resource):
AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
raise AccountNotFound()

View File

@ -3,8 +3,7 @@ from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.agent_service import AgentService
@ -21,12 +20,11 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@ -45,9 +43,5 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
current_user, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))

View File

@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -87,7 +79,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)

View File

@ -5,8 +5,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -18,12 +18,11 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -72,12 +71,11 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account)
current_user, current_tenant_id = current_account_with_tenant()
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")

View File

@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -1,7 +1,6 @@
import io
from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
model_provider_service = ModelProviderService()
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)

View File

@ -1,6 +1,5 @@
import logging
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"model_type",
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
try:
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
try:
model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required
@account_initialization_required
def get(self, model_type):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -1,7 +1,6 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {
@ -44,7 +43,7 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")

View File

@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import (
Resource,
reqparse,
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
@ -53,10 +52,9 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@ -78,9 +76,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
@ -96,9 +92,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -109,11 +103,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
@ -131,10 +124,9 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -161,13 +153,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@ -193,7 +184,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
@ -218,13 +209,12 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -258,10 +248,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -282,10 +271,9 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -308,13 +296,12 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@ -350,13 +337,12 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -377,10 +363,9 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@ -401,8 +386,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
@ -444,9 +428,9 @@ class ToolApiProviderPreviousTestApi(Resource):
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
current_tenant_id,
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
@ -462,13 +446,12 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@ -502,13 +485,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -545,13 +527,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@ -571,10 +552,9 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@ -606,10 +586,9 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@ -631,10 +610,9 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@ -653,8 +631,7 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
[
@ -672,10 +649,9 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@ -709,19 +685,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
@ -800,11 +775,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
@ -819,13 +795,13 @@ class ToolOAuthCustomClient(Resource):
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@ -835,20 +811,18 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@ -858,9 +832,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
tenant_id=current_tenant_id, provider_name=provider
)
)
@ -871,7 +846,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@ -900,12 +875,12 @@ class ToolProviderMCPApi(Resource):
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
@ -940,8 +915,9 @@ class ToolProviderMCPApi(Resource):
pass
else:
raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
@ -962,7 +938,8 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@ -977,7 +954,7 @@ class ToolMCPAuthApi(Resource):
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
@ -1018,8 +995,8 @@ class ToolMCPDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1029,8 +1006,7 @@ class ToolMCPListAllApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@ -1043,7 +1019,7 @@ class ToolMCPUpdateApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,

View File

@ -23,8 +23,8 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_user, login_required
from models.account import Account, Tenant, TenantStatus
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService
from services.file_service import FileService
@ -70,8 +70,7 @@ class TenantListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
@ -85,7 +84,7 @@ class TenantListApi(Resource):
"status": tenant.status,
"created_at": tenant.created_at,
"plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
"current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}
tenant_dicts.append(tenant_dict)
@ -130,8 +129,7 @@ class TenantApi(Resource):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@ -155,8 +153,7 @@ class SwitchWorkspaceApi(Resource):
@login_required
@account_initialization_required
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
args = parser.parse_args()
@ -181,16 +178,12 @@ class CustomConfigWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("remove_webapp_brand", type=bool, location="json")
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@ -212,8 +205,7 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -253,15 +245,14 @@ class WorkspaceInfoApi(Resource):
@account_initialization_required
# Change workspace name
def post(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
if not current_tenant_id:
raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args["name"]
db.session.commit()

View File

@ -12,8 +12,8 @@ from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_user
from models.account import Account, AccountStatus
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -25,18 +25,12 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = _current_account()
if account.status == AccountStatus.UNINITIALIZED:
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
return view(*args, **kwargs)
@ -80,9 +74,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -94,10 +87,8 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@ -138,9 +129,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@ -163,13 +153,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
_, current_tenant_id = current_account_with_tenant()
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
key = f"rate_limit_{current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@ -180,7 +168,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
tenant_id=current_tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@ -200,17 +188,15 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(tenant_id, utm_info_dict)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -260,9 +246,9 @@ def email_password_login_enabled(view: Callable[P, R]):
return decorated
def email_register_enabled(view):
def email_register_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.is_allow_register:
return view(*args, **kwargs)
@ -289,9 +275,8 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@ -301,14 +286,26 @@ def is_allow_transfer_owner(view: Callable[P, R]):
return decorated
def knowledge_pipeline_publish_enabled(view):
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
return decorated
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -17,7 +17,7 @@ class BaseMail(Resource):
def post(self):
args = _mail_parser.parse_args()
send_inner_email_task.delay(
send_inner_email_task.delay( # type: ignore
to=args["to"],
subject=args["subject"],
body=args["body"],

View File

@ -31,7 +31,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import length_prefixed_response
from models.account import Account, Tenant
from models import Account, Tenant
from models.model import EndUser

View File

@ -7,7 +7,7 @@ from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from models.account import Account
from models import Account
from services.account_service import TenantService

View File

@ -10,7 +10,7 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models.account import Account
from models import Account
from models.model import App
from services.annotation_service import AppAnnotationService

View File

@ -17,7 +17,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError

View File

@ -1,5 +1,4 @@
from flask import request
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import NotFound
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import current_account_with_tenant
from models.dataset import Dataset
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
# check dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Document not found.")
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Segment not found.")
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -17,7 +17,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
from services.feature_service import FeatureService

View File

@ -20,7 +20,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from models import Account
from services.account_service import AccountService

View File

@ -70,8 +70,7 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow

View File

@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService

View File

@ -112,7 +112,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@ -207,6 +207,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)

View File

@ -61,7 +61,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (

View File

@ -472,6 +472,9 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
session.commit()
except Exception:
@ -1145,6 +1148,15 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Can't add same credential")
provider_model_record.credential_id = credential_record.id
provider_model_record.updated_at = naive_utc_now()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
session.add(provider_model_record)
session.commit()
@ -1178,6 +1190,14 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.

View File

@ -7,7 +7,7 @@ import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final
from typing import Final, cast
from urllib.parse import urljoin
import httpx
@ -199,7 +199,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
return cast(int, uuid_obj.int)
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e

View File

@ -13,6 +13,7 @@ class TracingProviderEnum(StrEnum):
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
TENCENT = "tencent"
class BaseTracingConfig(BaseModel):
@ -195,5 +196,32 @@ class AliyunConfig(BaseTracingConfig):
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
class TencentConfig(BaseTracingConfig):
"""
Tencent APM tracing config
"""
token: str
endpoint: str
service_name: str
@field_validator("token")
@classmethod
def token_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("Token cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
@field_validator("service_name")
@classmethod
def service_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -90,6 +90,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo):
class DatasetRetrievalTraceInfo(BaseTraceInfo):
documents: Any = None
error: str | None = None
class ToolTraceInfo(BaseTraceInfo):

View File

@ -120,6 +120,17 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
"trace_instance": AliyunDataTrace,
}
case TracingProviderEnum.TENCENT:
from core.ops.entities.config_entity import TencentConfig
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
return {
"config_class": TencentConfig,
"secret_keys": ["token"],
"other_keys": ["endpoint", "service_name"],
"trace_instance": TencentDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
@ -723,6 +734,7 @@ class TraceTask:
end_time=timer.get("end"),
metadata=metadata,
message_data=message_data.to_dict(),
error=kwargs.get("error"),
)
return dataset_retrieval_trace_info
@ -889,6 +901,7 @@ class TraceQueueManager:
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
trace_info_type=type(trace_info).__name__,
@ -900,4 +913,4 @@ class TraceQueueManager:
"file_id": file_id,
"app_id": task.app_id,
}
process_trace_tasks.delay(file_info)
process_trace_tasks.delay(file_info) # type: ignore

View File

View File

@ -0,0 +1,337 @@
"""
Tencent APM Trace Client - handles network operations, metrics, and API communication
"""
from __future__ import annotations
import importlib
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from opentelemetry.metrics import Meter
from opentelemetry.metrics._internal.instrument import Histogram
from opentelemetry.sdk.metrics.export import MetricReader
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
from configs import dify_config
from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.tencent_trace_entity import SpanData
logger = logging.getLogger(__name__)
class TencentTraceClient:
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
def __init__(
self,
service_name: str,
endpoint: str,
token: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
metrics_export_interval_sec: int = 10,
):
self.endpoint = endpoint
self.token = token
self.service_name = service_name
self.metrics_export_interval_sec = metrics_export_interval_sec
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
# Prepare gRPC endpoint/metadata
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
headers = (("authorization", f"Bearer {token}"),)
self.exporter = OTLPSpanExporter(
endpoint=grpc_endpoint,
headers=headers,
insecure=insecure,
timeout=30,
)
self.tracer_provider = TracerProvider(resource=self.resource)
self.span_processor = BatchSpanProcessor(
span_exporter=self.exporter,
max_queue_size=max_queue_size,
schedule_delay_millis=schedule_delay_sec * 1000,
max_export_batch_size=max_export_batch_size,
)
self.tracer_provider.add_span_processor(self.span_processor)
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
# Store span contexts for parent-child relationships
self.span_contexts: dict[int, trace_api.SpanContext] = {}
self.meter: Meter | None = None
self.hist_llm_duration: Histogram | None = None
self.metric_reader: MetricReader | None = None
# Metrics exporter and instruments
try:
from opentelemetry import metrics
from opentelemetry.sdk.metrics import Histogram, MeterProvider
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
use_http_json = protocol in {"http/json", "http-json"}
# Set preferred temporality for histograms to DELTA
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
def _create_metric_exporter(exporter_cls, **kwargs):
"""Create metric exporter with preferred_temporality support"""
try:
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
except Exception:
return exporter_cls(**kwargs)
metric_reader = None
if use_http_json:
exporter_cls = None
for mod_path in (
"opentelemetry.exporter.otlp.http.json.metric_exporter",
"opentelemetry.exporter.otlp.json.metric_exporter",
):
try:
mod = importlib.import_module(mod_path)
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
if exporter_cls:
break
except Exception:
continue
if exporter_cls is not None:
metric_exporter = _create_metric_exporter(
exporter_cls,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
else:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
elif use_http_protobuf:
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
OTLPMetricExporter as HttpMetricExporter,
)
metric_exporter = _create_metric_exporter(
HttpMetricExporter,
endpoint=endpoint,
headers={"authorization": f"Bearer {token}"},
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
else:
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
OTLPMetricExporter as GrpcMetricExporter,
)
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
metric_exporter = _create_metric_exporter(
GrpcMetricExporter,
endpoint=m_grpc_endpoint,
headers={"authorization": f"Bearer {token}"},
insecure=m_insecure,
)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
)
if metric_reader is not None:
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
metrics.set_meter_provider(provider)
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
self.hist_llm_duration = self.meter.create_histogram(
name=LLM_OPERATION_DURATION,
unit="s",
description="LLM operation duration (seconds)",
)
self.metric_reader = metric_reader
else:
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
except Exception:
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
self.meter = None
self.hist_llm_duration = None
self.metric_reader = None
def add_span(self, span_data: SpanData) -> None:
"""Create and export span using OpenTelemetry Tracer API"""
try:
self._create_and_export_span(span_data)
logger.debug("[Tencent APM] Created span: %s", span_data.name)
except Exception:
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
# Metrics recording API
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
"""Record LLM operation duration histogram in seconds."""
try:
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
return
attrs: dict[str, str] = {}
if attributes:
for k, v in attributes.items():
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
except Exception:
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
def _create_and_export_span(self, span_data: SpanData) -> None:
"""Create span using OpenTelemetry Tracer API"""
try:
parent_context = None
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
parent_context = trace_api.set_span_in_context(
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
)
span = self.tracer.start_span(
name=span_data.name,
context=parent_context,
kind=SpanKind.INTERNAL,
start_time=span_data.start_time,
)
self.span_contexts[span_data.span_id] = span.get_span_context()
if span_data.attributes:
attributes: dict[str, AttributeValue] = {}
for key, value in span_data.attributes.items():
if isinstance(value, (int, float, bool)):
attributes[key] = value
else:
attributes[key] = str(value)
span.set_attributes(attributes)
if span_data.events:
for event in span_data.events:
span.add_event(event.name, event.attributes, event.timestamp)
if span_data.status:
span.set_status(span_data.status)
# Manually end span; do not use context manager to avoid double-end warnings
span.end(end_time=span_data.end_time)
except Exception:
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
def api_check(self) -> bool:
"""Check API connectivity using socket connection test for gRPC endpoints"""
try:
# Resolve gRPC target consistently with exporters
_, _, host, port = self._resolve_grpc_target(self.endpoint)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)
result = sock.connect_ex((host, port))
sock.close()
if result == 0:
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
return True
else:
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
if host in ["127.0.0.1", "localhost"]:
logger.info("[Tencent APM] Development environment detected, allowing config save")
return True
return False
except Exception:
logger.exception("[Tencent APM] API check failed")
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
return True
return False
def get_project_url(self) -> str:
"""Get project console URL"""
return "https://console.cloud.tencent.com/apm"
def shutdown(self) -> None:
"""Shutdown the client and export remaining spans"""
try:
if self.span_processor:
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
_ = self.span_processor.force_flush()
self.span_processor.shutdown()
if self.tracer_provider:
self.tracer_provider.shutdown()
if self.metric_reader is not None:
try:
self.metric_reader.shutdown() # type: ignore[attr-defined]
except Exception:
pass
except Exception:
logger.exception("[Tencent APM] Error during client shutdown")
@staticmethod
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
"""Normalize endpoint to gRPC target and security flag.
Returns:
(grpc_endpoint, insecure, host, port)
"""
try:
if endpoint.startswith(("http://", "https://")):
parsed = urlparse(endpoint)
host = parsed.hostname or "localhost"
port = parsed.port or default_port
insecure = parsed.scheme == "http"
return f"{host}:{port}", insecure, host, port
host = endpoint
port = default_port
if ":" in endpoint:
parts = endpoint.rsplit(":", 1)
host = parts[0] or "localhost"
try:
port = int(parts[1])
except Exception:
port = default_port
insecure = ("localhost" in host) or ("127.0.0.1" in host)
return f"{host}:{port}", insecure, host, port
except Exception:
host, port = "localhost", default_port
return f"{host}:{port}", True, host, port

View File

@ -0,0 +1 @@
# Tencent trace entities module

View File

@ -0,0 +1,73 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
# Chain
INPUT_VALUE = "gen_ai.entity.input"
OUTPUT_VALUE = "gen_ai.entity.output"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# GENERATION
GEN_AI_MODEL_NAME = "gen_ai.response.model"
GEN_AI_PROVIDER = "gen_ai.provider.name"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
# Instrumentation Library
INSTRUMENTATION_NAME = "dify-sdk"
INSTRUMENTATION_VERSION = "0.1.0"
INSTRUMENTATION_LANGUAGE = "python"
# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
class GenAISpanKind(Enum):
WORKFLOW = "WORKFLOW" # OpenLLMetry
RETRIEVER = "RETRIEVER" # RAG
GENERATION = "GENERATION" # Langfuse
TOOL = "TOOL" # OpenLLMetry
AGENT = "AGENT" # OpenLLMetry
TASK = "TASK" # OpenLLMetry

View File

@ -0,0 +1,21 @@
from collections.abc import Sequence
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
end_time: int = Field(..., description="The end time of the span in nanoseconds.")

View File

@ -0,0 +1,372 @@
"""
Tencent APM Span Builder - handles all span construction logic
"""
import json
import logging
from datetime import datetime
from opentelemetry.trace import Status, StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROVIDER,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.rag.models.document import Document
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
logger = logging.getLogger(__name__)
class TencentSpanBuilder:
"""Builder class for constructing different types of spans"""
@staticmethod
def _get_time_nanoseconds(time_value: datetime | None) -> int:
"""Convert datetime to nanoseconds for span creation."""
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
@staticmethod
def build_workflow_spans(
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> list[SpanData]:
"""Build workflow-related spans"""
spans = []
links = links or []
message_span_id = None
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id:
message_span = TencentSpanBuilder._build_message_span(
trace_info, trace_id, message_span_id, user_id, status, links
)
spans.append(message_span)
workflow_span = TencentSpanBuilder._build_workflow_span(
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
)
spans.append(workflow_span)
return spans
@staticmethod
def _build_message_span(
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
) -> SpanData:
"""Build message span for chatflow"""
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
links=links,
)
@staticmethod
def _build_workflow_span(
trace_info: WorkflowTraceInfo,
trace_id: int,
workflow_span_id: int,
message_span_id: int | None,
user_id: str,
status: Status,
links: list,
) -> SpanData:
"""Build workflow span"""
attributes = {
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
}
if message_span_id is None:
attributes[GEN_AI_IS_ENTRY] = "true"
return SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes=attributes,
status=status,
links=links,
)
@staticmethod
def build_workflow_llm_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build LLM span for workflow nodes."""
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name="GENERATION",
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_message_span(
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
) -> SpanData:
"""Build message span."""
links = links or []
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
name="message",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_IS_ENTRY: "true",
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: str(trace_info.outputs or ""),
},
status=status,
links=links,
)
@staticmethod
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build tool span."""
status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
name=trace_info.tool_name,
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: "",
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
@staticmethod
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
"""Build dataset retrieval span."""
status = Status(StatusCode.OK)
if getattr(trace_info, "error", None):
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
return SpanData(
trace_id=trace_id,
parent_span_id=parent_span_id,
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
name="retrieval",
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs or ""),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
status=status,
)
@staticmethod
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
"""Get workflow node execution status."""
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
@staticmethod
def build_workflow_retrieval_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build knowledge retrieval span for workflow nodes."""
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_tool_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build tool span for workflow nodes."""
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def build_workflow_task_span(
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
"""Build generic task span for workflow nodes."""
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
)
@staticmethod
def _extract_retrieval_documents(documents: list[Document]):
"""Extract documents data for retrieval tracing."""
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@ -0,0 +1,317 @@
"""
Tencent APM tracing implementation with separated concerns
"""
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.client import TencentTraceClient
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
"""
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
token=tencent_config.token,
metrics_export_interval_sec=5,
)
def trace(self, trace_info: BaseTraceInfo) -> None:
"""Main tracing entry point - coordinates different trace types."""
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
pass
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self) -> bool:
return self.trace_client.api_check()
def get_project_url(self) -> str:
return self.trace_client.get_project_url()
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
"""Handle workflow tracing by coordinating data retrieval and span construction."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
user_id = self._get_user_id(trace_info)
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
for span in workflow_spans:
self.trace_client.add_span(span)
self._process_workflow_nodes(trace_info, trace_id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow trace")
def message_trace(self, trace_info: MessageTraceInfo) -> None:
"""Handle message tracing."""
try:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
user_id = self._get_user_id(trace_info)
links = []
if trace_info.trace_id:
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
self.trace_client.add_span(message_span)
except Exception:
logger.exception("[Tencent APM] Failed to process message trace")
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
"""Handle tool tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(tool_span)
except Exception:
logger.exception("[Tencent APM] Failed to process tool trace")
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
"""Handle dataset retrieval tracing."""
try:
parent_span_id = None
trace_root_id = None
if trace_info.message_id:
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
trace_root_id = trace_info.message_id
if parent_span_id and trace_root_id:
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
self.trace_client.add_span(retrieval_span)
except Exception:
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
"""Handle suggested question tracing"""
try:
logger.info("[Tencent APM] Processing suggested question trace")
except Exception:
logger.exception("[Tencent APM] Failed to process suggested question trace")
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
"""Process workflow node executions."""
try:
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
node_executions = self._get_workflow_node_executions(trace_info)
for node_execution in node_executions:
try:
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
except Exception:
logger.exception("[Tencent APM] Failed to process workflow nodes")
def _build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)
else:
# Handle all other node types as generic tasks
return TencentSpanBuilder.build_workflow_task_span(
trace_id, workflow_span_id, trace_info, node_execution
)
except Exception:
logger.debug(
"[Tencent APM] Error building span for node %s: %s",
node_execution.id,
node_execution.node_type,
exc_info=True,
)
return None
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
"""Retrieve workflow node executions from database."""
try:
session_maker = sessionmaker(bind=db.engine)
with Session(db.engine, expire_on_commit=False) as session:
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_maker,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
return list(executions)
except Exception:
logger.exception("[Tencent APM] Failed to get workflow node executions")
return []
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
"""Get user ID from trace info."""
try:
tenant_id = None
user_id = None
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
tenant_id = trace_info.tenant_id
if hasattr(trace_info, "metadata") and trace_info.metadata:
user_id = trace_info.metadata.get("user_id")
if user_id and tenant_id:
stmt = (
select(Account.name)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
)
session_maker = sessionmaker(bind=db.engine)
with session_maker() as session:
account_name = session.scalar(stmt)
return account_name or str(user_id)
elif user_id:
return str(user_id)
return "anonymous"
except Exception:
logger.exception("[Tencent APM] Failed to get user ID")
return "unknown"
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
"""Record LLM performance metrics"""
try:
if not hasattr(self.trace_client, "record_llm_duration"):
return
process_data = node_execution.process_data or {}
usage = process_data.get("usage", {})
latency_s = float(usage.get("latency", 0.0))
if latency_s > 0:
attributes = {
"provider": process_data.get("model_provider", ""),
"model": process_data.get("model_name", ""),
"span_kind": "GENERATION",
}
self.trace_client.record_llm_duration(latency_s, attributes)
except Exception:
logger.debug("[Tencent APM] Failed to record LLM metrics")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
except Exception:
pass

View File

@ -0,0 +1,65 @@
"""
Utility functions for Tencent APM tracing
"""
import hashlib
import random
import uuid
from datetime import datetime
from typing import cast
from opentelemetry.trace import Link, SpanContext, TraceFlags
class TencentTraceUtils:
"""Utility class for common tracing operations."""
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
@staticmethod
def convert_to_trace_id(uuid_v4: str | None) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
return cast(int, uuid_obj.int)
@staticmethod
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
@staticmethod
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
@staticmethod
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
if start_time is None:
start_time = datetime.now()
timestamp_in_seconds = start_time.timestamp()
return int(timestamp_in_seconds * 1e9)
@staticmethod
def create_link(trace_id_str: str) -> Link:
try:
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
except (ValueError, TypeError):
trace_id = cast(int, uuid.uuid4().int)
span_context = SpanContext(
trace_id=trace_id,
span_id=TencentTraceUtils.INVALID_SPAN_ID,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
return Link(span_context)

View File

@ -14,7 +14,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.model import App, AppMode, EndUser

View File

@ -2,7 +2,7 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import httpx
from pydantic import BaseModel
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
)
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:
plugin_daemon_request_timeout = None
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
else:
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -58,6 +69,7 @@ class BasePluginClient:
"headers": headers,
"params": params,
"files": files,
"timeout": plugin_daemon_request_timeout,
}
if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data
@ -116,6 +128,7 @@ class BasePluginClient:
"headers": headers,
"params": params,
"files": files,
"timeout": plugin_daemon_request_timeout,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data

View File

@ -1,9 +1,24 @@
"""
Weaviate vector database implementation for Dify's RAG system.
This module provides integration with Weaviate vector database for storing and retrieving
document embeddings used in retrieval-augmented generation workflows.
"""
import datetime
import json
import logging
import uuid as _uuid
from typing import Any
from urllib.parse import urlparse
import weaviate # type: ignore
import weaviate
import weaviate.classes.config as wc
from pydantic import BaseModel, model_validator
from weaviate.classes.data import DataObject
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.exceptions import UnexpectedStatusCodeError
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -15,265 +30,394 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
Attributes:
endpoint: Weaviate server endpoint URL
api_key: Optional API key for authentication
batch_size: Number of objects to batch per insert operation
"""
endpoint: str
api_key: str | None = None
batch_size: int = 100
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict) -> dict:
"""Validates that required configuration values are present."""
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
"""
Weaviate vector database implementation for document storage and retrieval.
Handles creation, insertion, deletion, and querying of document embeddings
in a Weaviate collection.
"""
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
Args:
collection_name: Name of the Weaviate collection
config: Weaviate configuration settings
attributes: List of metadata attributes to store
"""
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
Configures both HTTP and gRPC connections with proper authentication.
"""
p = urlparse(config.endpoint)
host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "")
http_secure = p.scheme == "https"
http_port = p.port or (443 if http_secure else 80)
try:
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc
grpc_host = host
grpc_secure = http_secure
grpc_port = 443 if grpc_secure else 50051
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
client = weaviate.connect_to_custom(
http_host=host,
http_port=http_port,
http_secure=http_secure,
grpc_host=grpc_host,
grpc_port=grpc_port,
grpc_secure=grpc_secure,
auth_credentials=Auth.api_key(config.api_key) if config.api_key else None,
)
if not client.is_ready():
raise ConnectionError("Vector database is not ready")
return client
def get_type(self) -> str:
"""Returns the vector database type identifier."""
return VectorType.WEAVIATE
def get_collection_name(self, dataset: Dataset) -> str:
"""
Retrieves or generates the collection name for a dataset.
Uses existing index structure if available, otherwise generates from dataset ID.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
if not class_prefix.endswith("_Node"):
# original class_prefix
class_prefix += "_Node"
return class_prefix
dataset_id = dataset.id
return Dataset.gen_collection_name_by_id(dataset_id)
def to_index_struct(self):
def to_index_struct(self) -> dict:
"""Returns the index structure dictionary for persistence."""
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
# create collection
"""
Creates a new collection and adds initial documents with embeddings.
"""
self._create_collection()
# create vector
self.add_texts(texts, embeddings)
def _create_collection(self):
"""
Creates the Weaviate collection with required schema if it doesn't exist.
Uses Redis locking to prevent concurrent creation attempts.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(cache_key):
return
schema = self._default_schema(self._collection_name)
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
try:
if not self._client.collections.exists(self._collection_name):
self._client.collections.create(
name=self._collection_name,
properties=[
wc.Property(
name=Field.TEXT_KEY.value,
data_type=wc.DataType.TEXT,
tokenization=wc.Tokenization.WORD,
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
)
self._ensure_properties()
redis_client.set(cache_key, 1, ex=3600)
except Exception as e:
logger.exception("Error creating collection %s", self._collection_name)
raise
def _ensure_properties(self) -> None:
"""
Ensures all required properties exist in the collection schema.
Adds missing properties if the collection exists but lacks them.
"""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
cfg = col.config.get()
existing = {p.name for p in (cfg.properties or [])}
to_add = []
if "document_id" not in existing:
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))
for prop in to_add:
try:
col.config.add_property(prop)
except Exception as e:
logger.warning("Could not add property %s: %s", prop.name, e)
def _get_uuids(self, documents: list[Document]) -> list[str]:
"""
Generates deterministic UUIDs for documents based on their content.
Uses UUID5 with URL namespace to ensure consistent IDs for identical content.
"""
URL_NAMESPACE = _uuid.UUID("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
uuids = []
for doc in documents:
uuid_val = _uuid.uuid5(URL_NAMESPACE, doc.page_content)
uuids.append(str(uuid_val))
return uuids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Adds documents with their embeddings to the collection.
Batches insertions for efficiency and returns the list of inserted object IDs.
"""
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
col = self._client.collections.use(self._collection_name)
objs: list[DataObject] = []
ids_out: list[str] = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
for i, text in enumerate(texts):
props: dict[str, Any] = {Field.TEXT_KEY.value: text}
meta = metadatas[i] or {}
for k, v in meta.items():
props[k] = self._json_serializable(v)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
candidate = uuids[i] if uuids else None
uid = candidate if (candidate and self._is_uuid(candidate)) else str(_uuid.uuid4())
ids_out.append(uid)
vec_payload = None
if embeddings and i < len(embeddings) and embeddings[i]:
vec_payload = {"default": embeddings[i]}
objs.append(
DataObject(
uuid=uid,
properties=props, # type: ignore[arg-type] # mypy incorrectly infers DataObject signature
vector=vec_payload,
)
ids.append(uuids[i])
return ids
)
def delete_by_metadata_field(self, key: str, value: str):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {"operator": "Equal", "path": [key], "valueText": value}
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
with col.batch.dynamic() as batch:
for obj in objs:
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal")
return ids_out
def _is_uuid(self, val: str) -> bool:
"""Validates whether a string is a valid UUID format."""
try:
_uuid.UUID(str(val))
return True
except Exception:
return False
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Deletes all objects matching a specific metadata field value."""
if not self._client.collections.exists(self._collection_name):
return
col = self._client.collections.use(self._collection_name)
col.data.delete_many(where=Filter.by_property(key).equal(value))
def delete(self):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name)
"""Deletes the entire collection from Weaviate."""
if self._client.collections.exists(self._collection_name):
self._client.collections.delete(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
"""Checks if a document with the given doc_id exists in the collection."""
if not self._client.collections.exists(self._collection_name):
return False
result = (
self._client.query.get(collection_name)
.with_additional(["id"])
.with_where(
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}
)
.with_limit(1)
.do()
col = self._client.collections.use(self._collection_name)
res = col.query.fetch_objects(
filters=Filter.by_property("doc_id").equal(id),
limit=1,
return_properties=["doc_id"],
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
return len(res.objects) > 0
entries = result["data"]["Get"][collection_name]
if len(entries) == 0:
return False
def delete_by_ids(self, ids: list[str]) -> None:
"""
Deletes objects by their UUID identifiers.
return True
Silently ignores 404 errors for non-existent IDs.
"""
if not self._client.collections.exists(self._collection_name):
return
def delete_by_ids(self, ids: list[str]):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
for uuid in ids:
try:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
except weaviate.UnexpectedStatusCodeException as e:
# tolerate not found error
if e.status_code != 404:
raise e
col = self._client.collections.use(self._collection_name)
for uid in ids:
try:
col.data.delete_by_id(uid)
except UnexpectedStatusCodeError as e:
if getattr(e, "status_code", None) != 404:
raise
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
"""
Performs vector similarity search using the provided query vector.
vector = {"vector": query_vector}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
operands = []
for document_id_filter in document_ids_filter:
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
where_filter = {"operator": "Or", "operands": operands}
query_obj = query_obj.with_where(where_filter)
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
Filters by document IDs if provided and applies score threshold.
Returns documents sorted by relevance score.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
res = col.query.near_vector(
near_vector=query_vector,
limit=top_k,
return_properties=props,
return_metadata=MetadataQuery(distance=True),
include_vector=False,
filters=where,
target_vector="default",
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
score = 1.0 - distance
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
if score > score_threshold:
properties["score"] = score
docs.append(Document(page_content=text, metadata=properties))
docs.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
operands = []
for document_id_filter in document_ids_filter:
operands.append({"path": ["document_id"], "operator": "Equal", "valueText": document_id_filter})
where_filter = {"operator": "Or", "operands": operands}
query_obj = query_obj.with_where(where_filter)
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
Performs BM25 full-text search on document content.
Filters by document IDs if provided and returns matching documents with vectors.
"""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
top_k = int(kwargs.get("top_k", 4))
res = col.query.bm25(
query=query,
query_properties=[Field.TEXT_KEY.value],
limit=top_k,
return_properties=props,
include_vector=True,
filters=where,
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
vec = obj.vector
if isinstance(vec, dict):
vec = vec.get("default") or next(iter(vec.values()), None)
docs.append(Document(page_content=text, vector=vec, metadata=properties))
return docs
def _default_schema(self, index_name: str):
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any):
def _json_serializable(self, value: Any) -> Any:
"""Converts values to JSON-serializable format, handling datetime objects."""
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
class WeaviateVectorFactory(AbstractVectorFactory):
"""Factory class for creating WeaviateVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
"""
Initializes a WeaviateVector instance for the given dataset.
Uses existing collection name from dataset index structure or generates a new one.
Updates dataset index structure if not already set.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
@ -281,7 +425,6 @@ class WeaviateVectorFactory(AbstractVectorFactory):
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(

View File

@ -25,7 +25,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
if response.status_code == 200:
response_data = response.json()
data = response_data["data"]
@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
if response.status_code == 200:
# There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
@ -51,9 +51,25 @@ class FirecrawlApp:
self._handle_error(response, "start crawl job")
return "" # unreachable
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
headers = self._prepare_headers()
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "start map job")
return {}
else:
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@ -135,12 +151,16 @@ class FirecrawlApp:
"lang": "en",
"country": "us",
"timeout": 60000,
"ignoreInvalidURLs": False,
"ignoreInvalidURLs": True,
"scrapeOptions": {},
"sources": [
{"type": "web"},
],
"integration": "dify",
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):

View File

@ -108,7 +108,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_execution_task.delay(
save_workflow_execution_task.delay( # type: ignore
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",

View File

@ -189,6 +189,11 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
@field_validator("metadata", mode="before")
@classmethod
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
return value or {}
class RetrieverResourceMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
@ -376,6 +381,11 @@ class ToolEntity(BaseModel):
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
@field_validator("output_schema", mode="before")
@classmethod
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
return value or {}
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(

View File

@ -12,7 +12,7 @@ from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models.account import Account
from models import Account
logger = logging.getLogger(__name__)

View File

@ -41,6 +41,7 @@ class RedisChannel:
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
self._pending_key = f"{channel_key}:pending"
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
@ -49,6 +50,9 @@ class RedisChannel:
Returns:
List of pending commands (drains the Redis list)
"""
if not self._has_pending_commands():
return []
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
@ -85,6 +89,7 @@ class RedisChannel:
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.set(self._pending_key, "1", ex=self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
@ -112,3 +117,17 @@ class RedisChannel:
except (ValueError, TypeError):
return None
def _has_pending_commands(self) -> bool:
"""
Check and consume the pending marker to avoid unnecessary list reads.
Returns:
True if commands should be fetched from Redis.
"""
with self._redis.pipeline() as pipe:
pipe.get(self._pending_key)
pipe.delete(self._pending_key)
pending_value, _ = pipe.execute()
return pending_value is not None

View File

@ -7,6 +7,7 @@ from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
@ -125,6 +126,7 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
self._graph_runtime_state.increment_node_run_steps()
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
@ -163,6 +165,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -212,6 +216,8 @@ class EventHandler:
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
self._accumulate_node_usage(event.node_run_result.llm_usage)
result = self._error_handler.handle_node_failure(event)
if result:
@ -235,6 +241,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -286,6 +294,19 @@ class EventHandler:
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
"""Accumulate token usage into the shared runtime state."""
if usage.total_tokens <= 0:
return
self._graph_runtime_state.add_tokens(usage.total_tokens)
current_usage = self._graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self._graph_runtime_state.llm_usage = usage
else:
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.

View File

@ -8,7 +8,12 @@ import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunSucceededEvent,
)
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
@ -72,13 +77,16 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
# Check for scaling
self._execution_coordinator.check_scaling()
@ -87,6 +95,8 @@ class Dispatcher:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
if self._should_check_commands(event):
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
@ -102,3 +112,7 @@ class Dispatcher:
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
"""Return True if the event represents a node completion."""
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)

View File

@ -1,10 +1,13 @@
from events.dataset_event import dataset_was_deleted
from models import Dataset
from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender, **kwargs):
def handle(sender: Dataset, **kwargs):
dataset = sender
assert dataset.doc_form
assert dataset.indexing_technique
clean_dataset_task.delay(
dataset.id,
dataset.tenant_id,

Some files were not shown because too many files have changed in this diff Show More