Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN- 2025-09-11 15:13:31 +08:00
commit 85064bd8cf
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
105 changed files with 3132 additions and 568 deletions

View File

@ -19,11 +19,23 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/enterprise'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
username: ${{ secrets.ENTERPRISE_SSH_USER }}
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
script: |
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
- name: trigger deployments
env:
DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }}
DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }}
run: |
IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}"
BODY='{"project":"dify-api","tag":"deploy-enterprise"}'
for ENDPOINT in "${ENDPOINTS[@]}"; do
ENDPOINT="$(echo "$ENDPOINT" | xargs)"
[ -z "$ENDPOINT" ] && continue
API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}')
curl -sSf -X POST \
-H "Content-Type: application/json" \
-H "X-Hub-Signature-256: $API_SIGNATURE" \
-d "$BODY" \
"$ENDPOINT"
done

View File

@ -45,6 +45,7 @@ select = [
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
]
ignore = [

View File

@ -212,7 +212,9 @@ def migrate_annotation_vector_database():
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
annotations = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@ -367,29 +369,25 @@ def migrate_knowledge_vector_database():
)
raise e
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
).all()
for segment in segments:
document = Document(

View File

@ -1,4 +1,5 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(bp)
api = ExternalApi(
bp,
version="1.0",
title="Console API",
description="Console management APIs for app configuration, monitoring, and administration",
)
# Create namespace
console_ns = Namespace("console", description="Console management API operations", path="/")
# File
api.add_resource(FileApi, "/files/upload")
@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]
from . import (
admin, # pyright: ignore[reportUnusedImport]
apikey, # pyright: ignore[reportUnusedImport]
extension, # pyright: ignore[reportUnusedImport]
feature, # pyright: ignore[reportUnusedImport]
init_validate, # pyright: ignore[reportUnusedImport]
ping, # pyright: ignore[reportUnusedImport]
setup, # pyright: ignore[reportUnusedImport]
version, # pyright: ignore[reportUnusedImport]
)
# Import app controllers
from .app import (
@ -103,6 +122,23 @@ from .explore import (
saved_message, # pyright: ignore[reportUnusedImport]
)
# Import tag controllers
from .tag import tags # pyright: ignore[reportUnusedImport]
# Import workspace controllers
from .workspace import (
account, # pyright: ignore[reportUnusedImport]
agent_providers, # pyright: ignore[reportUnusedImport]
endpoint, # pyright: ignore[reportUnusedImport]
load_balancing_config, # pyright: ignore[reportUnusedImport]
members, # pyright: ignore[reportUnusedImport]
model_providers, # pyright: ignore[reportUnusedImport]
models, # pyright: ignore[reportUnusedImport]
plugin, # pyright: ignore[reportUnusedImport]
tool_providers, # pyright: ignore[reportUnusedImport]
workspace, # pyright: ignore[reportUnusedImport]
)
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
@ -174,19 +210,4 @@ api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)
# Import tag controllers
from .tag import tags # pyright: ignore[reportUnusedImport]
# Import workspace controllers
from .workspace import (
account, # pyright: ignore[reportUnusedImport]
agent_providers, # pyright: ignore[reportUnusedImport]
endpoint, # pyright: ignore[reportUnusedImport]
load_balancing_config, # pyright: ignore[reportUnusedImport]
members, # pyright: ignore[reportUnusedImport]
model_providers, # pyright: ignore[reportUnusedImport]
models, # pyright: ignore[reportUnusedImport]
plugin, # pyright: ignore[reportUnusedImport]
tool_providers, # pyright: ignore[reportUnusedImport]
workspace, # pyright: ignore[reportUnusedImport]
)
api.add_namespace(console_ns)

View File

@ -3,7 +3,7 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
@ -12,7 +12,7 @@ P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]):
return decorated
@console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource):
@api.doc("insert_explore_app")
@api.doc(description="Insert or update an app in the explore list")
@api.expect(
api.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@api.response(200, "App updated successfully")
@api.response(201, "App inserted successfully")
@api.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource):
return {"result": "success"}, 200
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource):
@api.doc("delete_explore_app")
@api.doc(description="Remove an app from the explore list")
@api.doc(params={"app_id": "Application ID to remove"})
@api.response(204, "App removed successfully")
@only_edition_cloud
@admin_required
def delete(self, app_id):
@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource):
db.session.commit()
return {"result": "success"}, 204
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")

View File

@ -14,7 +14,7 @@ from libs.login import login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from . import api
from . import api, console_ns
from .wraps import account_initialization_required, setup_required
api_key_fields = {
@ -60,11 +60,11 @@ class BaseApiKeyListResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return {"items": keys}
@marshal_with(api_key_fields)
@ -135,7 +135,25 @@ class BaseApiKeyResource(Resource):
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_app_api_keys")
@api.doc(description="Get all API keys for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
"""Get all API keys for an app"""
return super().get(resource_id)
@api.doc("create_app_api_key")
@api.doc(description="Create a new API key for an app")
@api.doc(params={"resource_id": "App ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
"""Create a new API key for an app"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -147,7 +165,16 @@ class AppApiKeyListResource(BaseApiKeyListResource):
token_prefix = "app-"
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class AppApiKeyResource(BaseApiKeyResource):
@api.doc("delete_app_api_key")
@api.doc(description="Delete an API key for an app")
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -158,7 +185,25 @@ class AppApiKeyResource(BaseApiKeyResource):
resource_id_field = "app_id"
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource):
@api.doc("get_dataset_api_keys")
@api.doc(description="Get all API keys for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(200, "Success", api_key_list)
def get(self, resource_id):
"""Get all API keys for a dataset"""
return super().get(resource_id)
@api.doc("create_dataset_api_key")
@api.doc(description="Create a new API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID"})
@api.response(201, "API key created successfully", api_key_fields)
@api.response(400, "Maximum keys exceeded")
def post(self, resource_id):
"""Create a new API key for a dataset"""
return super().post(resource_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -170,7 +215,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
token_prefix = "ds-"
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class DatasetApiKeyResource(BaseApiKeyResource):
@api.doc("delete_dataset_api_key")
@api.doc(description="Delete an API key for a dataset")
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id)
def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*"
resp.headers["Access-Control-Allow-Credentials"] = "true"
@ -179,9 +233,3 @@ class DatasetApiKeyResource(BaseApiKeyResource):
resource_type = "dataset"
resource_model = Dataset
resource_id_field = "dataset_id"
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")

View File

@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
@ -105,6 +105,12 @@ class ChatMessageApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")

View File

@ -172,7 +172,7 @@ class MessageAnnotationApi(Resource):
def post(self, app_model):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -2,8 +2,8 @@ import json
from typing import cast
from flask import request
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.wraps import get_app_model
@ -13,7 +13,8 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.model import AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -25,6 +26,13 @@ class ModelConfigResource(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model):
"""Modify app model config"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,

View File

@ -70,7 +70,7 @@ class DraftWorkflowApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# fetch draft workflow by app_model
@ -93,7 +93,7 @@ class DraftWorkflowApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
content_type = request.headers.get("Content-Type", "")
@ -171,7 +171,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
if not isinstance(current_user, Account):
@ -221,7 +221,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -257,7 +257,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -294,7 +294,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -331,7 +331,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -368,7 +368,7 @@ class DraftWorkflowRunApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -407,7 +407,7 @@ class WorkflowTaskStopApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
@ -434,7 +434,7 @@ class DraftWorkflowNodeRunApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -482,7 +482,7 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# fetch published workflow by app_model
@ -503,7 +503,7 @@ class PublishedWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -553,7 +553,7 @@ class DefaultBlockConfigsApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs
@ -573,7 +573,7 @@ class DefaultBlockConfigApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -608,7 +608,7 @@ class ConvertToWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
if request.data:
@ -657,7 +657,7 @@ class PublishedAllWorkflowApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -708,7 +708,7 @@ class WorkflowByIdApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
parser = reqparse.RequestParser()
@ -764,7 +764,7 @@ class WorkflowByIdApi(Resource):
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
workflow_service = WorkflowService()

View File

@ -138,7 +138,7 @@ def _api_prerequisite(f):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -1,8 +1,8 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from constants.languages import supported_language
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = reqparse.RequestParser()
active_check_parser.add_argument(
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
)
active_check_parser.add_argument(
"email", type=email, required=False, nullable=True, location="args", help="Email address"
)
active_check_parser.add_argument(
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
)
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
@api.doc("check_activation_token")
@api.doc(description="Check if activation token is valid")
@api.expect(active_check_parser)
@api.response(
200,
"Success",
api.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
args = active_check_parser.parse_args()
workspaceId = args["workspace_id"]
reg_email = args["email"]
@ -38,18 +60,36 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = reqparse.RequestParser()
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
active_parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
@console_ns.route("/activate")
class ActivateApi(Resource):
@api.doc("activate_account")
@api.doc(description="Activate account with invitation token")
@api.expect(active_parser)
@api.response(
200,
"Account activated successfully",
api.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
},
),
)
@api.response(400, "Already activated or invalid token")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"
)
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
args = parser.parse_args()
args = active_parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
if invitation is None:
@ -70,7 +110,3 @@ class ActivateApi(Resource):
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(ActivateCheckApi, "/activate/check")
api.add_resource(ActivateApi, "/activate")

View File

@ -3,11 +3,11 @@ import logging
import requests
from flask import current_app, redirect, request
from flask_login import current_user
from flask_restx import Resource
from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console import api, console_ns
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@ -28,7 +28,21 @@ def get_oauth_providers():
return OAUTH_PROVIDERS
@console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource):
@api.doc("oauth_data_source")
@api.doc(description="Get OAuth authorization URL for data source provider")
@api.doc(params={"provider": "Data source provider name (notion)"})
@api.response(
200,
"Authorization URL or internal setup success",
api.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
@api.response(400, "Invalid provider")
@api.response(403, "Admin privileges required")
def get(self, provider: str):
# The role of the current user in the table must be admin or owner
if not current_user.is_admin_or_owner:
@ -49,7 +63,19 @@ class OAuthDataSource(Resource):
return {"data": auth_url}, 200
@console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource):
@api.doc("oauth_data_source_callback")
@api.doc(description="Handle OAuth callback from data source provider")
@api.doc(
params={
"provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider",
}
)
@api.response(302, "Redirect to console with result")
@api.response(400, "Invalid provider")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
@console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource):
@api.doc("oauth_data_source_binding")
@api.doc(description="Bind OAuth data source with authorization code")
@api.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
)
@api.response(
200,
"Data source binding success",
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or code")
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource):
return {"result": "success"}, 200
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource):
@api.doc("oauth_data_source_sync")
@api.doc(description="Sync data from OAuth data source")
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@api.response(
200,
"Data source sync success",
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid provider or sync failed")
@setup_required
@login_required
@account_initialization_required
@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource):
return {"error": "OAuth data source process failed"}, 400
return {"result": "success"}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")

View File

@ -2,12 +2,12 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
@ -28,7 +28,32 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@api.doc("send_forgot_password_email")
@api.doc(description="Send password reset email")
@api.expect(
api.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@api.response(
200,
"Email sent successfully",
api.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
)
@api.response(400, "Invalid email or rate limit exceeded")
@setup_required
@email_password_login_enabled
def post(self):
@ -61,7 +86,33 @@ class ForgotPasswordSendEmailApi(Resource):
return {"result": "success", "data": token}
@console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@api.doc("check_forgot_password_code")
@api.doc(description="Verify password reset code")
@api.expect(
api.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@api.response(
200,
"Code verified successfully",
api.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
)
@api.response(400, "Invalid code or token")
@setup_required
@email_password_login_enabled
def post(self):
@ -100,7 +151,26 @@ class ForgotPasswordCheckApi(Resource):
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@api.doc("reset_password")
@api.doc(description="Reset password with verification token")
@api.expect(
api.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@api.response(
200,
"Password reset successfully",
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Invalid token or password mismatch")
@setup_required
@email_password_login_enabled
def post(self):
@ -172,8 +242,3 @@ class ForgotPasswordResetApi(Resource):
pass
except AccountRegisterError:
raise AccountInFreezeError()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api
from .. import api, console_ns
logger = logging.getLogger(__name__)
@ -50,7 +50,13 @@ def get_oauth_providers():
return OAUTH_PROVIDERS
@console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource):
@api.doc("oauth_login")
@api.doc(description="Initiate OAuth login process")
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
@api.response(302, "Redirect to OAuth authorization URL")
@api.response(400, "Invalid provider")
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
@ -63,7 +69,19 @@ class OAuthLogin(Resource):
return redirect(auth_url)
@console_ns.route("/oauth/authorize/<provider>")
class OAuthCallback(Resource):
@api.doc("oauth_callback")
@api.doc(description="Handle OAuth callback and complete login process")
@api.doc(
params={
"provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)",
}
)
@api.response(302, "Redirect to console with access token")
@api.response(400, "OAuth process failed")
def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@ -184,7 +202,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
AccountService.link_account_integrate(provider, user_info.id, account)
return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")

View File

@ -29,14 +29,12 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.where(
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
.all()
)
).all()
base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source"

View File

@ -2,6 +2,7 @@ import flask_restx
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource):
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
).all()
if file_details is None:
raise NotFound("File not found.")
@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource):
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
).all()
documents_status = []
for document in documents:
completed_segments = (
@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
).all()
return {"items": keys}
@setup_required

View File

@ -1,5 +1,6 @@
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast
from flask import request
@ -79,7 +80,7 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")

View File

@ -3,7 +3,7 @@ from typing import Any
from flask import request
from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
@ -33,13 +33,15 @@ class InstalledAppsListApi(Resource):
current_tenant_id = current_user.current_tenant_id
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
installed_apps = db.session.scalars(
select(InstalledApp).where(
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
)
).all()
else:
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.scalars(
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
).all()
if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")

View File

@ -1,8 +1,8 @@
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@console_ns.route("/code-based-extension")
class CodeBasedExtensionAPI(Resource):
@api.doc("get_code_based_extension")
@api.doc(description="Get code-based extension data by module name")
@api.expect(
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
)
@api.response(
200,
"Success",
api.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
)
@setup_required
@login_required
@account_initialization_required
@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource):
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@api.doc("get_api_based_extensions")
@api.doc(description="Get all API-based extensions for current tenant")
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
@setup_required
@login_required
@account_initialization_required
@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource):
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@api.doc("create_api_based_extension")
@api.doc(description="Create a new API-based extension")
@api.expect(
api.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@api.response(201, "Extension created successfully", api_based_extension_fields)
@setup_required
@login_required
@account_initialization_required
@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource):
return APIBasedExtensionService.save(extension_data)
@console_ns.route("/api-based-extension/<uuid:id>")
class APIBasedExtensionDetailAPI(Resource):
@api.doc("get_api_based_extension")
@api.doc(description="Get API-based extension by ID")
@api.doc(params={"id": "Extension ID"})
@api.response(200, "Success", api_based_extension_fields)
@setup_required
@login_required
@account_initialization_required
@ -64,6 +100,20 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@api.doc("update_api_based_extension")
@api.doc(description="Update API-based extension")
@api.doc(params={"id": "Extension ID"})
@api.expect(
api.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@api.response(200, "Extension updated successfully", api_based_extension_fields)
@setup_required
@login_required
@account_initialization_required
@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.save(extension_data_from_db)
@api.doc("delete_api_based_extension")
@api.doc(description="Delete API-based extension")
@api.doc(params={"id": "Extension ID"})
@api.response(204, "Extension deleted successfully")
@setup_required
@login_required
@account_initialization_required
@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}, 204
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")

View File

@ -1,26 +1,40 @@
from flask_login import current_user
from flask_restx import Resource
from flask_restx import Resource, fields
from libs.login import login_required
from services.feature_service import FeatureService
from . import api
from . import api, console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
@console_ns.route("/features")
class FeatureApi(Resource):
@api.doc("get_tenant_features")
@api.doc(description="Get feature configuration for current tenant")
@api.response(
200,
"Success",
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
@console_ns.route("/system-features")
class SystemFeatureApi(Resource):
@api.doc("get_system_features")
@api.doc(description="Get system-wide feature configuration")
@api.response(
200,
"Success",
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
)
def get(self):
"""Get system-wide feature configuration"""
return FeatureService.get_system_features().model_dump()
api.add_resource(FeatureApi, "/features")
api.add_resource(SystemFeatureApi, "/system-features")

View File

@ -1,7 +1,7 @@
import os
from flask import session
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -11,20 +11,47 @@ from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
from . import api
from . import api, console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
@console_ns.route("/init")
class InitValidateAPI(Resource):
@api.doc("get_init_status")
@api.doc(description="Get initialization validation status")
@api.response(
200,
"Success",
model=api.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
@api.doc("validate_init_password")
@api.doc(description="Validate initialization password for self-hosted edition")
@api.expect(
api.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@api.response(
201,
"Success",
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@api.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
@ -52,6 +79,3 @@ def get_init_validate_status():
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True
api.add_resource(InitValidateAPI, "/init")

View File

@ -1,14 +1,17 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.console import api
from . import api, console_ns
@console_ns.route("/ping")
class PingApi(Resource):
@api.doc("health_check")
@api.doc(description="Health check endpoint for connection testing")
@api.response(
200,
"Success",
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
)
def get(self):
"""
For connection health check
"""
"""Health check endpoint for connection testing"""
return {"result": "pong"}
api.add_resource(PingApi, "/ping")

View File

@ -1,5 +1,5 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
@ -7,23 +7,56 @@ from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
from . import api
from . import api, console_ns
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
@console_ns.route("/setup")
class SetupApi(Resource):
@api.doc("get_setup_status")
@api.doc(description="Get system setup status")
@api.response(
200,
"Success",
api.model(
"SetupStatusResponse",
{
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
},
),
)
def get(self):
"""Get system setup status"""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status:
# Check if setup_status is a DifySetup object rather than a bool
if setup_status and not isinstance(setup_status, bool):
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
elif setup_status:
return {"step": "finished"}
return {"step": "not_started"}
return {"step": "finished"}
@api.doc("setup_system")
@api.doc(description="Initialize system setup with admin account")
@api.expect(
api.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
},
)
)
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
@api.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Initialize system setup with admin account"""
# is set up
if get_setup_status():
raise AlreadySetupError()
@ -55,6 +88,3 @@ def get_setup_status():
return db.session.query(DifySetup).first()
else:
return True
api.add_resource(SetupApi, "/setup")

View File

@ -2,18 +2,41 @@ import json
import logging
import requests
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from packaging import version
from configs import dify_config
from . import api
from . import api, console_ns
logger = logging.getLogger(__name__)
@console_ns.route("/version")
class VersionApi(Resource):
@api.doc("check_version_update")
@api.doc(description="Check for application version updates")
@api.expect(
api.parser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
)
)
@api.response(
200,
"Success",
api.model(
"VersionResponse",
{
"version": fields.String(description="Latest version number"),
"release_date": fields.String(description="Release date of latest version"),
"release_notes": fields.String(description="Release notes for latest version"),
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
"features": fields.Raw(description="Feature flags and capabilities"),
},
),
)
def get(self):
"""Check for application version updates"""
parser = reqparse.RequestParser()
parser.add_argument("current_version", type=str, required=True, location="args")
args = parser.parse_args()
@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
except version.InvalidVersion:
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
return False
api.add_resource(VersionApi, "/version")

View File

@ -248,7 +248,9 @@ class AccountIntegrateApi(Resource):
raise ValueError("Invalid user account")
account = current_user
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()
base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"

View File

@ -1,14 +1,22 @@
from flask_login import current_user
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.agent_service import AgentService
@console_ns.route("/workspaces/current/agent-providers")
class AgentProviderListApi(Resource):
@api.doc("list_agent_providers")
@api.doc(description="Get list of available agent providers")
@api.response(
200,
"Success",
fields.List(fields.Raw(description="Agent provider information")),
)
@setup_required
@login_required
@account_initialization_required
@ -21,7 +29,16 @@ class AgentProviderListApi(Resource):
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
class AgentProviderApi(Resource):
@api.doc("get_agent_provider")
@api.doc(description="Get specific agent provider details")
@api.doc(params={"provider_name": "Agent provider name"})
@api.response(
200,
"Success",
fields.Raw(description="Agent provider details"),
)
@setup_required
@login_required
@account_initialization_required
@ -30,7 +47,3 @@ class AgentProviderApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")

View File

@ -1,8 +1,8 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
@ -10,7 +10,26 @@ from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@api.doc(description="Create a new plugin endpoint")
@api.expect(
api.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@api.response(
200,
"Endpoint created successfully",
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@account_initialization_required
@ -43,7 +62,20 @@ class EndpointCreateApi(Resource):
raise ValueError(e.description) from e
@console_ns.route("/workspaces/current/endpoints/list")
class EndpointListApi(Resource):
@api.doc("list_endpoints")
@api.doc(description="List plugin endpoints with pagination")
@api.expect(
api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@api.response(
200,
"Success",
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
)
@setup_required
@login_required
@account_initialization_required
@ -70,7 +102,23 @@ class EndpointListApi(Resource):
)
@console_ns.route("/workspaces/current/endpoints/list/plugin")
class EndpointListForSinglePluginApi(Resource):
@api.doc("list_plugin_endpoints")
@api.doc(description="List endpoints for a specific plugin")
@api.expect(
api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@api.response(
200,
"Success",
api.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
)
@setup_required
@login_required
@account_initialization_required
@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource):
)
@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource):
@api.doc("delete_endpoint")
@api.doc(description="Delete a plugin endpoint")
@api.expect(
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint deleted successfully",
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@account_initialization_required
@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource):
}
@console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource):
@api.doc("update_endpoint")
@api.doc(description="Update a plugin endpoint")
@api.expect(
api.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@api.response(
200,
"Endpoint updated successfully",
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@account_initialization_required
@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource):
}
@console_ns.route("/workspaces/current/endpoints/enable")
class EndpointEnableApi(Resource):
@api.doc("enable_endpoint")
@api.doc(description="Enable a plugin endpoint")
@api.expect(
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint enabled successfully",
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@account_initialization_required
@ -177,7 +268,19 @@ class EndpointEnableApi(Resource):
}
@console_ns.route("/workspaces/current/endpoints/disable")
class EndpointDisableApi(Resource):
@api.doc("disable_endpoint")
@api.doc(description="Disable a plugin endpoint")
@api.expect(
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200,
"Endpoint disabled successfully",
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@api.response(403, "Admin privileges required")
@setup_required
@login_required
@account_initialization_required
@ -198,12 +301,3 @@ class EndpointDisableApi(Resource):
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
}
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")

View File

@ -10,7 +10,6 @@ api = ExternalApi(
version="1.0",
title="Files API",
description="API for file operations including upload and preview",
doc="/docs", # Enable Swagger UI at /files/docs
)
files_ns = Namespace("files", description="File operations", path="/")

View File

@ -10,7 +10,6 @@ api = ExternalApi(
version="1.0",
title="Inner API",
description="Internal APIs for enterprise features, billing, and plugin communication",
doc="/docs", # Enable Swagger UI at /inner/api/docs
)
# Create namespace

View File

@ -75,9 +75,6 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
if not user_id:
user_id = DEFAULT_SERVICE_API_USER_ID
del kwargs["tenant_id"]
del kwargs["user_id"]
try:
tenant_model = (
db.session.query(Tenant)

View File

@ -10,7 +10,6 @@ api = ExternalApi(
version="1.0",
title="MCP API",
description="API for Model Context Protocol operations",
doc="/docs", # Enable Swagger UI at /mcp/docs
)
mcp_ns = Namespace("mcp", description="MCP operations", path="/")

View File

@ -10,7 +10,6 @@ api = ExternalApi(
version="1.0",
title="Service API",
description="API for application services",
doc="/docs", # Enable Swagger UI at /v1/docs
)
service_api_ns = Namespace("service_api", description="Service operations", path="/")

View File

@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource):
def put(self, app_model: App, annotation_id):
"""Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
"""Delete an annotation."""
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)

View File

@ -559,7 +559,7 @@ class DatasetTagsApi(DatasetApiResource):
def post(self, _, dataset_id):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_create_parser.parse_args()
@ -583,7 +583,7 @@ class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token
def patch(self, _, dataset_id):
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_update_parser.parse_args()
@ -610,7 +610,7 @@ class DatasetTagsApi(DatasetApiResource):
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id"))
@ -634,7 +634,7 @@ class DatasetTagBindingApi(DatasetApiResource):
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_binding_parser.parse_args()
@ -660,7 +660,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
def post(self, _, dataset_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
assert isinstance(current_user, Account)
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_unbinding_parser.parse_args()

View File

@ -10,7 +10,6 @@ api = ExternalApi(
version="1.0",
title="Web API",
description="Public APIs for web applications including file uploads, chat interactions, and app management",
doc="/docs", # Enable Swagger UI at /api/docs
)
# Create namespace

View File

@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse
from werkzeug.exceptions import InternalServerError
import services
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
AudioTooLargeError,
@ -32,15 +32,16 @@ from services.errors.audio import (
logger = logging.getLogger(__name__)
@web_ns.route("/audio-to-text")
class AudioApi(WebApiResource):
audio_to_text_response_fields = {
"text": fields.String,
}
@marshal_with(audio_to_text_response_fields)
@api.doc("Audio to Text")
@api.doc(description="Convert audio file to text using speech-to-text service.")
@api.doc(
@web_ns.doc("Audio to Text")
@web_ns.doc(description="Convert audio file to text using speech-to-text service.")
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -85,6 +86,7 @@ class AudioApi(WebApiResource):
raise InternalServerError()
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
text_to_audio_response_fields = {
"audio_url": fields.String,
@ -92,9 +94,9 @@ class TextApi(WebApiResource):
}
@marshal_with(text_to_audio_response_fields)
@api.doc("Text to Audio")
@api.doc(description="Convert text to audio using text-to-speech service.")
@api.doc(
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -145,7 +147,3 @@ class TextApi(WebApiResource):
except Exception as e:
logger.exception("Failed to handle post request to TextApi")
raise InternalServerError()
api.add_resource(AudioApi, "/audio-to-text")
api.add_resource(TextApi, "/text-to-audio")

View File

@ -4,7 +4,7 @@ from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
CompletionRequestError,
@ -35,10 +35,11 @@ logger = logging.getLogger(__name__)
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@api.doc("Create Completion Message")
@api.doc(description="Create a completion message for text generation applications.")
@api.doc(
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
@ -52,7 +53,7 @@ class CompletionApi(WebApiResource):
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@api.doc(
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -106,11 +107,12 @@ class CompletionApi(WebApiResource):
raise InternalServerError()
@web_ns.route("/completion-messages/<string:task_id>/stop")
class CompletionStopApi(WebApiResource):
@api.doc("Stop Completion Message")
@api.doc(description="Stop a running completion message task.")
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@api.doc(
@web_ns.doc("Stop Completion Message")
@web_ns.doc(description="Stop a running completion message task.")
@web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource):
return {"result": "success"}, 200
@web_ns.route("/chat-messages")
class ChatApi(WebApiResource):
@api.doc("Create Chat Message")
@api.doc(description="Create a chat message for conversational applications.")
@api.doc(
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
@ -148,7 +151,7 @@ class ChatApi(WebApiResource):
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@api.doc(
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -207,11 +210,12 @@ class ChatApi(WebApiResource):
raise InternalServerError()
@web_ns.route("/chat-messages/<string:task_id>/stop")
class ChatStopApi(WebApiResource):
@api.doc("Stop Chat Message")
@api.doc(description="Stop a running chat message task.")
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@api.doc(
@web_ns.doc("Stop Chat Message")
@web_ns.doc(description="Stop a running chat message task.")
@web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
return {"result": "success"}, 200
api.add_resource(CompletionApi, "/completion-messages")
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
api.add_resource(ChatApi, "/chat-messages")
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")

View File

@ -3,7 +3,7 @@ from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
from services.web_conversation_service import WebConversationService
@web_ns.route("/conversations")
class ConversationListApi(WebApiResource):
@web_ns.doc("Get Conversation List")
@web_ns.doc(description="Retrieve paginated list of conversations for a chat application.")
@web_ns.doc(
params={
"last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False},
"limit": {
"description": "Number of conversations to return (1-100)",
"type": "integer",
"required": False,
"default": 20,
},
"pinned": {
"description": "Filter by pinned status",
"type": "string",
"enum": ["true", "false"],
"required": False,
},
"sort_by": {
"description": "Sort order",
"type": "string",
"enum": ["created_at", "-created_at", "updated_at", "-updated_at"],
"required": False,
"default": "-updated_at",
},
}
)
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource):
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>")
class ConversationApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@web_ns.doc(
responses={
204: "Conversation deleted successfully",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Conversation Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
@ -76,7 +127,32 @@ class ConversationApi(WebApiResource):
return {"result": "success"}, 204
@web_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(WebApiResource):
@web_ns.doc("Rename Conversation")
@web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@web_ns.doc(
params={
"name": {"description": "New conversation name", "type": "string", "required": False},
"auto_generate": {
"description": "Auto-generate conversation name",
"type": "boolean",
"required": False,
"default": False,
},
}
)
@web_ns.doc(
responses={
200: "Conversation renamed successfully",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Conversation Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource):
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>/pin")
class ConversationPinApi(WebApiResource):
pin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Conversation pinned successfully",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Conversation Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource):
return {"result": "success"}
@web_ns.route("/conversations/<uuid:c_id>/unpin")
class ConversationUnPinApi(WebApiResource):
unpin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Conversation unpinned successfully",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Conversation Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource):
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="web_conversation_name")
api.add_resource(ConversationListApi, "/conversations")
api.add_resource(ConversationApi, "/conversations/<uuid:c_id>")
api.add_resource(ConversationPinApi, "/conversations/<uuid:c_id>/pin")
api.add_resource(ConversationUnPinApi, "/conversations/<uuid:c_id>/unpin")

View File

@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import (
AppMoreLikeThisDisabledError,
AppSuggestedQuestionsAfterAnswerDisabledError,
@ -38,6 +38,7 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
@ -62,6 +63,30 @@ class MessageListApi(WebApiResource):
"data": fields.List(fields.Nested(message_fields)),
}
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
params={
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
"limit": {
"description": "Number of messages to return (1-100)",
"type": "integer",
"required": False,
"default": 20,
},
}
)
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Conversation Not Found or Not a Chat App",
500: "Internal Server Error",
}
)
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
@ -84,11 +109,36 @@ class MessageListApi(WebApiResource):
raise NotFound("First Message Not Exists.")
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(WebApiResource):
feedback_response_fields = {
"result": fields.String,
}
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@web_ns.doc(
params={
"rating": {
"description": "Feedback rating",
"type": "string",
"enum": ["like", "dislike"],
"required": False,
},
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
}
)
@web_ns.doc(
responses={
200: "Feedback submitted successfully",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Message Not Found",
500: "Internal Server Error",
}
)
@marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource):
return {"result": "success"}
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
class MessageMoreLikeThisApi(WebApiResource):
@web_ns.doc("Generate More Like This")
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
@web_ns.doc(
params={
"message_id": {"description": "Message UUID", "type": "string", "required": True},
"response_mode": {
"description": "Response mode",
"type": "string",
"enum": ["blocking", "streaming"],
"required": True,
},
}
)
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request - Not a completion app or feature disabled",
401: "Unauthorized",
403: "Forbidden",
404: "Message Not Found",
500: "Internal Server Error",
}
)
def get(self, app_model, end_user, message_id):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource):
raise InternalServerError()
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
suggested_questions_response_fields = {
"data": fields.List(fields.String),
}
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request - Not a chat app or feature disabled",
401: "Unauthorized",
403: "Forbidden",
404: "Message Not Found or Conversation Not Found",
500: "Internal Server Error",
}
)
@marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise InternalServerError()
return {"data": questions}
api.add_resource(MessageListApi, "/messages")
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this")
api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions")

View File

@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import message_file_fields
@ -23,6 +23,7 @@ message_fields = {
}
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource):
"result": fields.String,
}
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
params={
"last_id": {"description": "Last message ID for pagination", "type": "string", "required": False},
"limit": {
"description": "Number of messages to return (1-100)",
"type": "integer",
"required": False,
"default": 20,
},
}
)
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request - Not a completion app",
401: "Unauthorized",
403: "Forbidden",
404: "App Not Found",
500: "Internal Server Error",
}
)
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource):
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@web_ns.doc(
params={
"message_id": {"description": "Message UUID to save", "type": "string", "required": True},
}
)
@web_ns.doc(
responses={
200: "Message saved successfully",
400: "Bad Request - Not a completion app",
401: "Unauthorized",
403: "Forbidden",
404: "Message Not Found",
500: "Internal Server Error",
}
)
@marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource):
return {"result": "success"}
@web_ns.route("/saved-messages/<uuid:message_id>")
class SavedMessageApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@web_ns.doc(
responses={
204: "Message removed successfully",
400: "Bad Request - Not a completion app",
401: "Unauthorized",
403: "Forbidden",
404: "Message Not Found",
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return {"result": "success"}, 204
api.add_resource(SavedMessageListApi, "/saved-messages")
api.add_resource(SavedMessageApi, "/saved-messages/<uuid:message_id>")

View File

@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.web import api
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
@ -11,6 +11,7 @@ from models.model import Site
from services.feature_service import FeatureService
@web_ns.route("/site")
class AppSiteApi(WebApiResource):
"""Resource for app sites."""
@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource):
"custom_config": fields.Raw(attribute="custom_config"),
}
@api.doc("Get App Site Info")
@api.doc(description="Retrieve app site information and configuration.")
@api.doc(
@web_ns.doc("Get App Site Info")
@web_ns.doc(description="Retrieve app site information and configuration.")
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource):
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
api.add_resource(AppSiteApi, "/site")
class AppSiteInfo:
"""Class to store site information."""

View File

@ -3,7 +3,7 @@ import logging
from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError
from controllers.web import api
from controllers.web import web_ns
from controllers.web.error import (
CompletionRequestError,
NotWorkflowAppError,
@ -30,16 +30,17 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@api.doc("Run Workflow")
@api.doc(description="Execute a workflow with provided inputs and files.")
@api.doc(
@web_ns.doc("Run Workflow")
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
}
)
@api.doc(
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -85,15 +86,16 @@ class WorkflowRunApi(WebApiResource):
raise InternalServerError()
@web_ns.route("/workflows/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(WebApiResource):
@api.doc("Stop Workflow Task")
@api.doc(description="Stop a running workflow task.")
@api.doc(
@web_ns.doc("Stop Workflow Task")
@web_ns.doc(description="Stop a running workflow task.")
@web_ns.doc(
params={
"task_id": {"description": "Task ID to stop", "type": "string", "required": True},
}
)
@api.doc(
@web_ns.doc(
responses={
200: "Success",
400: "Bad Request",
@ -119,7 +121,3 @@ class WorkflowTaskStopApi(WebApiResource):
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")

View File

@ -32,11 +32,16 @@ class TokenBufferMemory:
self.model_instance = model_instance
def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
self,
message_files: Sequence[MessageFile],
text_content: str,
message: Message,
app_record,
is_user_message: bool,
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param message_files: Sequence of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
@ -128,14 +133,12 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
user_files = db.session.scalars(
select(MessageFile).where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
).all()
if user_files:
user_prompt_message = self._build_prompt_message_with_files(
@ -150,11 +153,9 @@ class TokenBufferMemory:
prompt_messages.append(UserPromptMessage(content=message.query))
# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
assistant_files = db.session.scalars(
select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
).all()
if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(

View File

@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:

View File

@ -1,5 +1,6 @@
import time
import uuid
from collections.abc import Sequence
import requests
from requests.auth import HTTPDigestAuth
@ -139,7 +140,7 @@ class TidbService:
@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
tidb_serverless_list: Sequence[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,

View File

@ -1,4 +1,5 @@
from pydantic import Field
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_provider import ToolProviderController
@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController):
tools: list[ApiTool] = []
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)
db_providers = db.session.scalars(
select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
)
).all()
if db_providers and len(db_providers) != 0:
for db_provider in db_providers:

View File

@ -87,9 +87,7 @@ class ToolLabelManager:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

View File

@ -671,9 +671,9 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
db_api_providers = db.session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all()
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
@ -694,9 +694,9 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_providers = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers:

View File

@ -1,3 +1,5 @@
from sqlalchemy import select
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
@ -13,7 +15,7 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:

View File

@ -1,5 +1,7 @@
from typing import cast
from sqlalchemy import select
from core.workflow.nodes import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
@ -15,7 +17,7 @@ def handle(sender, **kwargs):
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:

View File

@ -7,6 +7,7 @@ import sqlalchemy as sa
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
from typing_extensions import deprecated
from models.base import Base
@ -187,7 +188,28 @@ class Account(UserMixin, Base):
return TenantAccountRole.is_admin_role(self.role)
@property
@deprecated("Use has_edit_permission instead.")
def is_editor(self):
"""Determines if the account has edit permissions in their current tenant (workspace).
This property checks if the current role has editing privileges, which includes:
- `OWNER`
- `ADMIN`
- `EDITOR`
Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically.
"""
return self.has_edit_permission
@property
def has_edit_permission(self):
"""Determines if the account has editing permissions in their current tenant (workspace).
This property checks if the current role has editing privileges, which includes:
- `OWNER`
- `ADMIN`
- `EDITOR`
"""
return TenantAccountRole.is_editing_role(self.role)
@property
@ -218,10 +240,12 @@ class Tenant(Base):
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
db.session.query(Account)
.where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all()
return list(
db.session.scalars(
select(Account).where(
Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
)
).all()
)
@property

View File

@ -208,7 +208,9 @@ class Dataset(Base):
@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
dataset_metadatas = db.session.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
).all()
doc_metadata = [
{
@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base):
@property
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
external_knowledge_bindings = db.session.scalars(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
).all()
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})

View File

@ -811,7 +811,7 @@ class Conversation(Base):
@property
def status_count(self):
messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@ -1090,7 +1090,7 @@ class Message(Base):
@property
def feedbacks(self):
feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all()
return feedbacks
@property
@ -1145,7 +1145,7 @@ class Message(Base):
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")

View File

@ -96,11 +96,11 @@ def clean_unused_datasets_task():
break
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
dataset_query = db.session.scalars(
select(DatasetQuery).where(
DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id
)
).all()
if not dataset_query or len(dataset_query) == 0:
try:
@ -121,15 +121,13 @@ def clean_unused_datasets_task():
if should_clean:
# Add auto disable log if required
if add_logs:
documents = (
db.session.query(Document)
.where(
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
).all()
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,

View File

@ -3,6 +3,7 @@ import time
from collections import defaultdict
import click
from sqlalchemy import select
import app
from configs import dify_config
@ -31,9 +32,9 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
)
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:

View File

@ -1,6 +1,8 @@
import time
from collections.abc import Sequence
import click
from sqlalchemy import select
import app
from configs import dify_config
@ -15,11 +17,9 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter()
try:
# check the number of idle tidb serverless
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
).all()
if len(tidb_serverless_list) == 0:
return
# update tidb serverless status
@ -32,7 +32,7 @@ def update_tidb_serverless_status_task():
click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green"))
def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]):
try:
# batch 20
for i in range(0, len(tidb_serverless_list), 20):

View File

@ -246,6 +246,8 @@ class AccountService:
account.name = name
if password:
valid_password(password)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()

View File

@ -263,11 +263,9 @@ class AppAnnotationService:
db.session.delete(annotation)
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)

View File

@ -1,5 +1,7 @@
import json
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
@staticmethod

View File

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:

View File

@ -6,6 +6,7 @@ import secrets
import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
import sqlalchemy as sa
@ -741,14 +742,12 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.where(
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
).all()
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -885,69 +884,58 @@ class DocumentService:
return document
@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)
).all()
return documents
@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
assert isinstance(current_user, Account)
documents = (
db.session.query(Document)
.where(
documents = db.session.scalars(
select(Document).where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()
return documents
@ -984,7 +972,7 @@ class DocumentService:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@ -2424,16 +2412,14 @@ class SegmentService:
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@ -2451,16 +2437,14 @@ class SegmentService:
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@ -2532,16 +2516,13 @@ class SegmentService:
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = (
db.session.query(ChildChunk)
.where(
child_chunks = db.session.scalars(
select(ChildChunk).where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.all()
)
).all()
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@ -2751,19 +2732,13 @@ class DatasetCollectionBindingService:
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = (
db.session.query(
user_list_query = db.session.scalars(
select(
DatasetPermission.account_id,
)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
).where(DatasetPermission.dataset_id == dataset_id)
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
return user_list_query
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):

View File

@ -3,7 +3,7 @@ import logging
from json import JSONDecodeError
from typing import Optional, Union
from sqlalchemy import or_
from sqlalchemy import or_, select
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@ -322,16 +322,14 @@ class ModelLoadBalancingService:
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
).all()
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}

View File

@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
from extensions.ext_database import db
from models.model import App, RecommendedApp
@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
:param language: language
:return:
"""
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
).all()
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
).all()
categories = set()
recommended_apps_result = []

View File

@ -2,7 +2,7 @@ import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -29,35 +29,30 @@ class TagService:
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = (
db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = (
db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
).all()
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
tags = list(
db.session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
if not tags:
return []
@ -117,7 +112,7 @@ class TagService:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@ -39,7 +40,7 @@ def enable_annotation_reply_task(
db.session.close()
return
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@ -55,8 +56,8 @@ def clean_dataset_task(
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled

View File

@ -4,6 +4,7 @@ from typing import Optional
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -4,6 +4,7 @@ from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()
if not segments:
db.session.close()

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str):
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

View File

@ -0,0 +1,101 @@
"""Integration tests for ChatMessageApi permission verification."""
import uuid
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import completion as completion_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models.model import AppMode
from services.app_generate_service import AppGenerateService
class TestChatMessageApiPermissions:
"""Test permission verification for ChatMessageApi endpoint."""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
return app
@pytest.fixture
def mock_account(self):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
return account
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_post_with_owner_role_succeeds(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test that OWNER role can access chat-messages endpoint."""
"""Setup common mocks for testing."""
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock current user
monkeypatch.setattr(completion_api, "current_user", mock_account)
mock_generate = mock.Mock(return_value={"message": "Test response"})
monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
# Set user role to OWNER
mock_account.role = role
response = test_client.post(
f"/console/api/apps/{mock_app_model.id}/chat-messages",
headers=auth_header,
json={
"inputs": {},
"query": "Hello, world!",
"model_config": {
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
},
"response_mode": "blocking",
},
)
assert response.status_code == status

View File

@ -0,0 +1,129 @@
"""Integration tests for ModelConfigResource permission verification."""
import uuid
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import model_config as model_config_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models.model import AppMode
from services.app_model_config_service import AppModelConfigService
class TestModelConfigResourcePermissions:
"""Test permission verification for ModelConfigResource endpoint."""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.app_model_config_id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_account(self):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
return account
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_post_with_owner_role_succeeds(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test that OWNER role can access model-config endpoint."""
# Set user role to OWNER
mock_account.role = role
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock current user
monkeypatch.setattr(model_config_api, "current_user", mock_account)
# Mock AccountService.load_user to prevent authentication issues
from services.account_service import AccountService
mock_load_user = mock.Mock(return_value=mock_account)
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
mock_validate_config = mock.Mock(
return_value={
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
"pre_prompt": "You are a helpful assistant.",
"user_input_form": [],
"dataset_query_variable": "",
"agent_mode": {"enabled": False, "tools": []},
}
)
monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
# Mock database operations
mock_db_session = mock.Mock()
mock_db_session.add = mock.Mock()
mock_db_session.flush = mock.Mock()
mock_db_session.commit = mock.Mock()
monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
# Mock app_model_config_was_updated event
mock_event = mock.Mock()
mock_event.send = mock.Mock()
monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
response = test_client.post(
f"/console/api/apps/{mock_app_model.id}/model-config",
headers=auth_header,
json={
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {"temperature": 0.7, "max_tokens": 1000},
},
"user_input_form": [],
"dataset_query_variable": "",
"pre_prompt": "You are a helpful assistant.",
"agent_mode": {"enabled": False, "tools": []},
},
)
assert response.status_code == status

View File

@ -91,6 +91,28 @@ class TestAccountService:
assert account.password is None
assert account.password_salt is None
def test_create_account_password_invalid_new_password(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test account create with invalid new password format.
"""
fake = Faker()
email = fake.email()
name = fake.name()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
# Test with too short password (assuming minimum length validation)
with pytest.raises(ValueError): # Password validation error
AccountService.create_account(
email=email,
name=name,
interface_language="en-US",
password="invalid_new_password",
)
def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test account creation when registration is disabled.
@ -940,7 +962,8 @@ class TestAccountService:
Test getting user through non-existent email.
"""
fake = Faker()
non_existent_email = fake.email()
domain = f"test-{fake.random_letters(10)}.com"
non_existent_email = fake.email(domain=domain)
found_user = AccountService.get_user_through_email(non_existent_email)
assert found_user is None

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
@ -468,7 +469,7 @@ class TestModelLoadBalancingService:
assert load_balancing_config.id is not None
# Verify inherit config was created in database
inherit_configs = (
db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
)
inherit_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
).all()
assert len(inherit_configs) == 1

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -954,7 +955,9 @@ class TestTagService:
from extensions.ext_database import db
# Verify only one binding exists
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
bindings = db.session.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 1
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
@ -1064,7 +1067,9 @@ class TestTagService:
# No error should be raised, and database state should remain unchanged
from extensions.ext_database import db
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
bindings = db.session.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 0
def test_check_target_exists_knowledge_success(

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account
@ -354,16 +355,14 @@ class TestWebConversationService:
# Verify only one pinned conversation record exists
from extensions.ext_database import db
pinned_conversations = (
db.session.query(PinnedConversation)
.where(
pinned_conversations = db.session.scalars(
select(PinnedConversation).where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
PinnedConversation.created_by_role == "account",
PinnedConversation.created_by == account.id,
)
.all()
)
).all()
assert len(pinned_conversations) == 1

View File

@ -246,6 +246,43 @@ class TestEmailI18nService:
sent_email = mock_sender.sent_emails[0]
assert sent_email["subject"] == "Reset Your Dify Password"
def test_subject_format_keyerror_fallback_path(
self,
mock_renderer: MockEmailRenderer,
mock_sender: MockEmailSender,
):
"""Trigger subject KeyError and cover except branch."""
# Config with subject that references an unknown key (no {application_title} to avoid second format)
config = EmailI18nConfig(
templates={
EmailType.INVITE_MEMBER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Invite: {unknown_placeholder}",
template_path="invite_member_en.html",
branded_template_path="branded/invite_member_en.html",
),
}
}
)
branding_service = MockBrandingService(enabled=False)
service = EmailI18nService(
config=config,
renderer=mock_renderer,
branding_service=branding_service,
sender=mock_sender,
)
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
service.send_email(
email_type=EmailType.INVITE_MEMBER,
language_code="en-US",
to="test@example.com",
)
assert len(mock_sender.sent_emails) == 1
# Subject is left unformatted due to KeyError fallback path without application_title
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
def test_send_change_email_old_phase(
self,
email_config: EmailI18nConfig,

View File

@ -0,0 +1,122 @@
from flask import Blueprint, Flask
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, Unauthorized
from core.errors.error import AppInvokeQuotaExceededError
from libs.external_api import ExternalApi
def _create_api_app():
app = Flask(__name__)
bp = Blueprint("t", __name__)
api = ExternalApi(bp)
@api.route("/bad-request")
class Bad(Resource): # type: ignore
def get(self): # type: ignore
raise BadRequest("invalid input")
@api.route("/unauth")
class Unauth(Resource): # type: ignore
def get(self): # type: ignore
raise Unauthorized("auth required")
@api.route("/value-error")
class ValErr(Resource): # type: ignore
def get(self): # type: ignore
raise ValueError("boom")
@api.route("/quota")
class Quota(Resource): # type: ignore
def get(self): # type: ignore
raise AppInvokeQuotaExceededError("quota exceeded")
@api.route("/general")
class Gen(Resource): # type: ignore
def get(self): # type: ignore
raise RuntimeError("oops")
# Note: We avoid altering default_mediatype to keep normal error paths
# Special 400 message rewrite
@api.route("/json-empty")
class JsonEmpty(Resource): # type: ignore
def get(self): # type: ignore
e = BadRequest()
# Force the specific message the handler rewrites
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
raise e
# 400 mapping payload path
@api.route("/param-errors")
class ParamErrors(Resource): # type: ignore
def get(self): # type: ignore
e = BadRequest()
# Coerce a mapping description to trigger param error shaping
e.description = {"field": "is required"} # type: ignore[assignment]
raise e
app.register_blueprint(bp, url_prefix="/api")
return app
def test_external_api_error_handlers_basic_paths():
app = _create_api_app()
client = app.test_client()
# 400
res = client.get("/api/bad-request")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "bad_request"
assert data["status"] == 400
# 401
res = client.get("/api/unauth")
assert res.status_code == 401
assert "WWW-Authenticate" in res.headers
# 400 ValueError
res = client.get("/api/value-error")
assert res.status_code == 400
assert res.get_json()["code"] == "invalid_param"
# 500 general
res = client.get("/api/general")
assert res.status_code == 500
assert res.get_json()["status"] == 500
def test_external_api_json_message_and_bad_request_rewrite():
app = _create_api_app()
client = app.test_client()
# JSON empty special rewrite
res = client.get("/api/json-empty")
assert res.status_code == 400
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
def test_external_api_param_mapping_and_quota_and_exc_info_none():
# Force exc_info() to return (None,None,None) only during request
import libs.external_api as ext
orig_exc_info = ext.sys.exc_info
try:
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
app = _create_api_app()
client = app.test_client()
# Param errors mapping payload path
res = client.get("/api/param-errors")
assert res.status_code == 400
data = res.get_json()
assert data["code"] == "invalid_param"
assert data["params"] == "field"
# Quota path — depending on Flask-RESTX internals it may be handled
res = client.get("/api/quota")
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]

View File

@ -0,0 +1,19 @@
import pytest
from libs.oauth import OAuth
def test_oauth_base_methods_raise_not_implemented():
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
with pytest.raises(NotImplementedError):
oauth.get_authorization_url()
with pytest.raises(NotImplementedError):
oauth.get_access_token("code")
with pytest.raises(NotImplementedError):
oauth.get_raw_user_info("token")
with pytest.raises(NotImplementedError):
oauth._transform_user_info({}) # type: ignore[name-defined]

View File

@ -0,0 +1,53 @@
from unittest.mock import MagicMock, patch
import pytest
from python_http_client.exceptions import UnauthorizedError
from libs.sendgrid import SendGridClient
def _mail(to: str = "user@example.com") -> dict:
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_success(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
# nested attribute access: client.mail.send.post
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
sg.send(_mail())
mock_client_cls.assert_called_once()
mock_client.client.mail.send.post.assert_called_once()
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(ValueError):
sg.send(_mail(to=""))
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(UnauthorizedError):
sg.send(_mail())
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
with pytest.raises(TimeoutError):
sg.send(_mail())

View File

@ -0,0 +1,100 @@
from unittest.mock import MagicMock, patch
import pytest
from libs.smtp import SMTPClient
def _mail() -> dict:
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=587,
username="user",
password="pass",
_from="noreply@example.com",
use_tls=True,
opportunistic_tls=True,
)
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
assert mock_smtp.ehlo.call_count == 2
mock_smtp.starttls.assert_called_once()
mock_smtp.login.assert_called_once_with("user", "pass")
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP_SSL")
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
# Cover SMTP_SSL branch and TimeoutError handling
mock_smtp = MagicMock()
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
mock_smtp_ssl_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=465,
username="",
password="",
_from="noreply@example.com",
use_tls=True,
opportunistic_tls=False,
)
with pytest.raises(TimeoutError):
client.send(_mail())
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp.sendmail.side_effect = RuntimeError("oops")
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
with pytest.raises(RuntimeError):
client.send(_mail())
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
# Ensure we hit the specific SMTPException except branch
import smtplib
mock_smtp = MagicMock()
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
mock_smtp_cls.return_value = mock_smtp
client = SMTPClient(
server="smtp.example.com",
port=25,
username="user", # non-empty to trigger login
password="pass",
_from="noreply@example.com",
)
with pytest.raises(smtplib.SMTPException):
client.send(_mail())
mock_smtp.quit.assert_called_once()

View File

@ -28,18 +28,20 @@ class TestApiKeyAuthService:
mock_binding.provider = self.provider
mock_binding.disabled = False
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
mock_session.scalars.return_value.all.return_value = [mock_binding]
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
assert len(result) == 1
assert result[0].tenant_id == self.tenant_id
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
assert mock_session.scalars.call_count == 1
select_arg = mock_session.scalars.call_args[0][0]
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
@ -48,13 +50,15 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify where conditions include disabled.is_(False)
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2 # tenant_id and disabled filter conditions
select_stmt = mock_session.scalars.call_args[0][0]
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
# Ensure both tenant filter and disabled filter exist
where_strs = [str(c).lower() for c in where_clauses]
assert any("tenant_id" in s for s in where_strs)
assert any("disabled" in s for s in where_strs)
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")

View File

@ -63,10 +63,10 @@ class TestAuthIntegration:
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1

View File

@ -24,7 +24,7 @@ const GA: FC<IGAProps> = ({
if (IS_CE_EDITION)
return null
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : ''
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : ''
return (
<>
@ -32,7 +32,7 @@ const GA: FC<IGAProps> = ({
strategy="beforeInteractive"
async
src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`}
nonce={nonce!}
nonce={nonce ?? undefined}
></Script>
<Script
id="ga-init"
@ -44,14 +44,14 @@ gtag('js', new Date());
gtag('config', '${gaIdMaps[gaType]}');
`,
}}
nonce={nonce!}
nonce={nonce ?? undefined}
>
</Script>
{/* Cookie banner */}
<Script
id="cookieyes"
src='https://cdn-cookieyes.com/client_data/2a645945fcae53f8e025a2b1/script.js'
nonce={nonce!}
nonce={nonce ?? undefined}
></Script>
</>

View File

@ -1,5 +1,6 @@
import React, { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useChatContext } from '../chat/chat/context'
const hasEndThink = (children: any): boolean => {
if (typeof children === 'string')
@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => {
}
const useThinkTimer = (children: any) => {
const { isResponding } = useChatContext()
const [startTime] = useState(Date.now())
const [elapsedTime, setElapsedTime] = useState(0)
const [isComplete, setIsComplete] = useState(false)
@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => {
}, [startTime, isComplete])
useEffect(() => {
if (hasEndThink(children))
if (hasEndThink(children) || !isResponding)
setIsComplete(true)
}, [children])
}, [children, isResponding])
return { elapsedTime, isComplete }
}

View File

@ -0,0 +1,21 @@
import { memo } from 'react'
import { type UnsafeUnwrappedHeaders, headers } from 'next/headers'
import Script from 'next/script'
import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config'
const Zendesk = () => {
if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY)
return null
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : ''
return (
<Script
nonce={nonce ?? undefined}
id="ze-snippet"
src={`https://static.zdassets.com/ekr/snippet.js?key=${ZENDESK_WIDGET_KEY}`}
/>
)
}
export default memo(Zendesk)

View File

@ -0,0 +1,23 @@
import { IS_CE_EDITION } from '@/config'
export type ConversationField = {
id: string,
value: any,
}
declare global {
// eslint-disable-next-line ts/consistent-type-definitions
interface Window {
zE?: (
command: string,
value: string,
payload?: ConversationField[] | string | string[] | (() => any),
callback?: () => any,
) => void;
}
}
export const setZendeskConversationFields = (fields: ConversationField[], callback?: () => any) => {
if (!IS_CE_EDITION && window.zE)
window.zE('messenger:set', 'conversationFields', fields, callback)
}

View File

@ -38,7 +38,7 @@ const Field: FC<Props> = ({
<div className={cn(className, inline && 'flex w-full items-center justify-between')}>
<div
onClick={() => supportFold && toggleFold()}
className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
className={cn('flex items-center justify-between', supportFold && 'cursor-pointer')}>
<div className='flex h-6 items-center'>
<div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}>
{title} {required && <span className='text-text-destructive'>*</span>}

Some files were not shown because too many files have changed in this diff Show More