mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
This commit is contained in:
commit
85064bd8cf
|
|
@ -19,11 +19,23 @@ jobs:
|
|||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.ENTERPRISE_SSH_HOST }}
|
||||
username: ${{ secrets.ENTERPRISE_SSH_USER }}
|
||||
password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }}
|
||||
script: |
|
||||
${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }}
|
||||
- name: trigger deployments
|
||||
env:
|
||||
DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }}
|
||||
DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }}
|
||||
run: |
|
||||
IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}"
|
||||
BODY='{"project":"dify-api","tag":"deploy-enterprise"}'
|
||||
|
||||
for ENDPOINT in "${ENDPOINTS[@]}"; do
|
||||
ENDPOINT="$(echo "$ENDPOINT" | xargs)"
|
||||
[ -z "$ENDPOINT" ] && continue
|
||||
|
||||
API_SIGNATURE=$(printf '%s' "$BODY" | openssl dgst -sha256 -hmac "$DEPLOY_SECRET" | awk '{print "sha256="$2}')
|
||||
|
||||
curl -sSf -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-Hub-Signature-256: $API_SIGNATURE" \
|
||||
-d "$BODY" \
|
||||
"$ENDPOINT"
|
||||
done
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ select = [
|
|||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum
|
||||
]
|
||||
|
||||
ignore = [
|
||||
|
|
|
|||
|
|
@ -212,7 +212,9 @@ def migrate_annotation_vector_database():
|
|||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
|
|
@ -367,29 +369,25 @@ def migrate_knowledge_vector_database():
|
|||
)
|
||||
raise e
|
||||
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset.id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
segments_count = 0
|
||||
for dataset_document in dataset_documents:
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
|
@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
|||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Console API",
|
||||
description="Console management APIs for app configuration, monitoring, and administration",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
|
|
@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
|||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport]
|
||||
from . import (
|
||||
admin, # pyright: ignore[reportUnusedImport]
|
||||
apikey, # pyright: ignore[reportUnusedImport]
|
||||
extension, # pyright: ignore[reportUnusedImport]
|
||||
feature, # pyright: ignore[reportUnusedImport]
|
||||
init_validate, # pyright: ignore[reportUnusedImport]
|
||||
ping, # pyright: ignore[reportUnusedImport]
|
||||
setup, # pyright: ignore[reportUnusedImport]
|
||||
version, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
|
|
@ -103,6 +122,23 @@ from .explore import (
|
|||
saved_message, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account, # pyright: ignore[reportUnusedImport]
|
||||
agent_providers, # pyright: ignore[reportUnusedImport]
|
||||
endpoint, # pyright: ignore[reportUnusedImport]
|
||||
load_balancing_config, # pyright: ignore[reportUnusedImport]
|
||||
members, # pyright: ignore[reportUnusedImport]
|
||||
model_providers, # pyright: ignore[reportUnusedImport]
|
||||
models, # pyright: ignore[reportUnusedImport]
|
||||
plugin, # pyright: ignore[reportUnusedImport]
|
||||
tool_providers, # pyright: ignore[reportUnusedImport]
|
||||
workspace, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
|
||||
# Explore Audio
|
||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||
|
|
@ -174,19 +210,4 @@ api.add_resource(
|
|||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account, # pyright: ignore[reportUnusedImport]
|
||||
agent_providers, # pyright: ignore[reportUnusedImport]
|
||||
endpoint, # pyright: ignore[reportUnusedImport]
|
||||
load_balancing_config, # pyright: ignore[reportUnusedImport]
|
||||
members, # pyright: ignore[reportUnusedImport]
|
||||
model_providers, # pyright: ignore[reportUnusedImport]
|
||||
models, # pyright: ignore[reportUnusedImport]
|
||||
plugin, # pyright: ignore[reportUnusedImport]
|
||||
tool_providers, # pyright: ignore[reportUnusedImport]
|
||||
workspace, # pyright: ignore[reportUnusedImport]
|
||||
)
|
||||
api.add_namespace(console_ns)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
|
@ -12,7 +12,7 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
|
@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]):
|
|||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@api.doc("insert_explore_app")
|
||||
@api.doc(description="Insert or update an app in the explore list")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
"desc": fields.String(description="App description"),
|
||||
"copyright": fields.String(description="Copyright information"),
|
||||
"privacy_policy": fields.String(description="Privacy policy"),
|
||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
||||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully")
|
||||
@api.response(201, "App inserted successfully")
|
||||
@api.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
|
|
@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@api.doc("delete_explore_app")
|
||||
@api.doc(description="Remove an app from the explore list")
|
||||
@api.doc(params={"app_id": "Application ID to remove"})
|
||||
@api.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
|
|
@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource):
|
|||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps")
|
||||
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>")
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from libs.login import login_required
|
|||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import api
|
||||
from . import api, console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
|
|
@ -60,11 +60,11 @@ class BaseApiKeyListResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
.all()
|
||||
)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_fields)
|
||||
|
|
@ -135,7 +135,25 @@ class BaseApiKeyResource(Resource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_app_api_keys")
|
||||
@api.doc(description="Get all API keys for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_app_api_key")
|
||||
@api.doc(description="Create a new API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
|
@ -147,7 +165,16 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
token_prefix = "app-"
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class AppApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_app_api_key")
|
||||
@api.doc(description="Delete an API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
|
@ -158,7 +185,25 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||
resource_id_field = "app_id"
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
|
||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_dataset_api_keys")
|
||||
@api.doc(description="Get all API keys for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_dataset_api_key")
|
||||
@api.doc(description="Create a new API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
|
@ -170,7 +215,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
token_prefix = "ds-"
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_dataset_api_key")
|
||||
@api.doc(description="Delete an API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
|
@ -179,9 +233,3 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
|
||||
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys")
|
||||
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys")
|
||||
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
|
|
@ -105,6 +105,12 @@ class ChatMessageApi(Resource):
|
|||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class MessageAnnotationApi(Resource):
|
|||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import json
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
|
|
@ -13,7 +13,8 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
|
@ -25,6 +26,13 @@ class ModelConfigResource(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class DraftWorkflowApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
|
|
@ -93,7 +93,7 @@ class DraftWorkflowApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
|
@ -171,7 +171,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -221,7 +221,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -257,7 +257,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -294,7 +294,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -331,7 +331,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -368,7 +368,7 @@ class DraftWorkflowRunApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -407,7 +407,7 @@ class WorkflowTaskStopApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
|
|
@ -434,7 +434,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -482,7 +482,7 @@ class PublishedWorkflowApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch published workflow by app_model
|
||||
|
|
@ -503,7 +503,7 @@ class PublishedWorkflowApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -553,7 +553,7 @@ class DefaultBlockConfigsApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
|
|
@ -573,7 +573,7 @@ class DefaultBlockConfigApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -608,7 +608,7 @@ class ConvertToWorkflowApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
|
|
@ -657,7 +657,7 @@ class PublishedAllWorkflowApi(Resource):
|
|||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -708,7 +708,7 @@ class WorkflowByIdApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -764,7 +764,7 @@ class WorkflowByIdApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ def _api_prerequisite(f):
|
|||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone
|
|||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = reqparse.RequestParser()
|
||||
active_check_parser.add_argument(
|
||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@api.doc("check_activation_token")
|
||||
@api.doc(description="Check if activation token is valid")
|
||||
@api.expect(active_check_parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
"data": fields.Raw(description="Activation data if valid"),
|
||||
},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="args")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
args = active_check_parser.parse_args()
|
||||
|
||||
workspaceId = args["workspace_id"]
|
||||
reg_email = args["email"]
|
||||
|
|
@ -38,18 +60,36 @@ class ActivateCheckApi(Resource):
|
|||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = reqparse.RequestParser()
|
||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
active_parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
)
|
||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@api.doc("activate_account")
|
||||
@api.doc(description="Activate account with invitation token")
|
||||
@api.expect(active_parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
api.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.Raw(description="Login token data"),
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = active_parser.parse_args()
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
||||
if invitation is None:
|
||||
|
|
@ -70,7 +110,3 @@ class ActivateApi(Resource):
|
|||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, "/activate/check")
|
||||
api.add_resource(ActivateApi, "/activate")
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import logging
|
|||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
|
|
@ -28,7 +28,21 @@ def get_oauth_providers():
|
|||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>")
|
||||
class OAuthDataSource(Resource):
|
||||
@api.doc("oauth_data_source")
|
||||
@api.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@api.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
api.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid provider")
|
||||
@api.response(403, "Admin privileges required")
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
|
|
@ -49,7 +63,19 @@ class OAuthDataSource(Resource):
|
|||
return {"data": auth_url}, 200
|
||||
|
||||
|
||||
@console_ns.route("/oauth/data-source/callback/<string:provider>")
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
@api.doc("oauth_data_source_callback")
|
||||
@api.doc(description="Handle OAuth callback from data source provider")
|
||||
@api.doc(
|
||||
params={
|
||||
"provider": "Data source provider name (notion)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"error": "Error message from OAuth provider",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with result")
|
||||
@api.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied")
|
||||
|
||||
|
||||
@console_ns.route("/oauth/data-source/binding/<string:provider>")
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
@api.doc("oauth_data_source_binding")
|
||||
@api.doc(description="Bind OAuth data source with authorization code")
|
||||
@api.doc(
|
||||
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@api.doc("oauth_data_source_sync")
|
||||
@api.doc(description="Sync data from OAuth data source")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource):
|
|||
return {"error": "OAuth data source process failed"}, 400
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
|
||||
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@ import base64
|
|||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
|
|
@ -28,7 +28,32 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@api.doc("send_forgot_password_email")
|
||||
@api.doc(description="Send password reset email")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
api.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
"data": fields.String(description="Reset token"),
|
||||
"code": fields.String(description="Error code if account not found"),
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
@ -61,7 +86,33 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@api.doc("check_forgot_password_code")
|
||||
@api.doc(description="Verify password reset code")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
"code": fields.String(required=True, description="Verification code"),
|
||||
"token": fields.String(required=True, description="Reset token"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
api.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
"email": fields.String(description="Email address"),
|
||||
"token": fields.String(description="New reset token"),
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
@ -100,7 +151,26 @@ class ForgotPasswordCheckApi(Resource):
|
|||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@api.doc("reset_password")
|
||||
@api.doc(description="Reset password with verification token")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
"new_password": fields.String(required=True, description="New password"),
|
||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
@ -172,8 +242,3 @@ class ForgotPasswordResetApi(Resource):
|
|||
pass
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
|
|||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import api
|
||||
from .. import api, console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -50,7 +50,13 @@ def get_oauth_providers():
|
|||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@api.doc("oauth_login")
|
||||
@api.doc(description="Initiate OAuth login process")
|
||||
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
|
||||
@api.response(302, "Redirect to OAuth authorization URL")
|
||||
@api.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
|
|
@ -63,7 +69,19 @@ class OAuthLogin(Resource):
|
|||
return redirect(auth_url)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/authorize/<provider>")
|
||||
class OAuthCallback(Resource):
|
||||
@api.doc("oauth_callback")
|
||||
@api.doc(description="Handle OAuth callback and complete login process")
|
||||
@api.doc(
|
||||
params={
|
||||
"provider": "OAuth provider name (github/google)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"state": "Optional state parameter (used for invite token)",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with access token")
|
||||
@api.response(400, "OAuth process failed")
|
||||
def get(self, provider: str):
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -184,7 +202,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
|
||||
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
|
||||
|
|
|
|||
|
|
@ -29,14 +29,12 @@ class DataSourceApi(Resource):
|
|||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
base_url = request.url_root.rstrip("/")
|
||||
data_source_oauth_base_path = "/console/api/oauth/data-source"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import flask_restx
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
|
|
@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
extract_settings = []
|
||||
if args["info_list"]["data_source_type"] == "upload_file":
|
||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||
file_details = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
|
||||
.all()
|
||||
)
|
||||
file_details = db.session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
if file_details is None:
|
||||
raise NotFound("File not found.")
|
||||
|
|
@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
|
|
@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
def get(self):
|
||||
keys = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask import request
|
||||
|
|
@ -79,7 +80,7 @@ class DocumentResource(Resource):
|
|||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Any
|
|||
|
||||
from flask import request
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
|
|
@ -33,13 +33,15 @@ class InstalledAppsListApi(Resource):
|
|||
current_tenant_id = current_user.current_tenant_id
|
||||
|
||||
if app_id:
|
||||
installed_apps = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
|
||||
.all()
|
||||
)
|
||||
installed_apps = db.session.scalars(
|
||||
select(InstalledApp).where(
|
||||
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
|
||||
)
|
||||
).all()
|
||||
else:
|
||||
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
installed_apps = db.session.scalars(
|
||||
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
|
||||
if current_user.current_tenant is None:
|
||||
raise ValueError("current_user.current_tenant must not be None")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
|
|
@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService
|
|||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
|
||||
@console_ns.route("/code-based-extension")
|
||||
class CodeBasedExtensionAPI(Resource):
|
||||
@api.doc("get_code_based_extension")
|
||||
@api.doc(description="Get code-based extension data by module name")
|
||||
@api.expect(
|
||||
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
"CodeBasedExtensionResponse",
|
||||
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
|
||||
),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource):
|
|||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension")
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
@api.doc("get_api_based_extensions")
|
||||
@api.doc(description="Get all API-based extensions for current tenant")
|
||||
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource):
|
|||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
@api.doc("create_api_based_extension")
|
||||
@api.doc(description="Create a new API-based extension")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Extension created successfully", api_based_extension_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource):
|
|||
return APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
class APIBasedExtensionDetailAPI(Resource):
|
||||
@api.doc("get_api_based_extension")
|
||||
@api.doc(description="Get API-based extension by ID")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.response(200, "Success", api_based_extension_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -64,6 +100,20 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
@api.doc("update_api_based_extension")
|
||||
@api.doc(description="Update API-based extension")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Extension updated successfully", api_based_extension_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
|
||||
@api.doc("delete_api_based_extension")
|
||||
@api.doc(description="Delete API-based extension")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.response(204, "Extension deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
|
||||
|
||||
api.add_resource(APIBasedExtensionAPI, "/api-based-extension")
|
||||
api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>")
|
||||
|
|
|
|||
|
|
@ -1,26 +1,40 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api
|
||||
from . import api, console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@api.doc("get_tenant_features")
|
||||
@api.doc(description="Get feature configuration for current tenant")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@api.doc("get_system_features")
|
||||
@api.doc(description="Get system-wide feature configuration")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
||||
|
||||
api.add_resource(FeatureApi, "/features")
|
||||
api.add_resource(SystemFeatureApi, "/system-features")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
|
||||
from flask import session
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -11,20 +11,47 @@ from libs.helper import StrLen
|
|||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
from . import api
|
||||
from . import api, console_ns
|
||||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
@console_ns.route("/init")
|
||||
class InitValidateAPI(Resource):
|
||||
@api.doc("get_init_status")
|
||||
@api.doc(description="Get initialization validation status")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
model=api.model(
|
||||
"InitStatusResponse",
|
||||
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get initialization validation status"""
|
||||
init_status = get_init_validate_status()
|
||||
if init_status:
|
||||
return {"status": "finished"}
|
||||
return {"status": "not_started"}
|
||||
|
||||
@api.doc("validate_init_password")
|
||||
@api.doc(description="Validate initialization password for self-hosted edition")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InitValidateRequest",
|
||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
201,
|
||||
"Success",
|
||||
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Validate initialization password"""
|
||||
# is tenant created
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
if tenant_count > 0:
|
||||
|
|
@ -52,6 +79,3 @@ def get_init_validate_status():
|
|||
return db_session.execute(select(DifySetup)).scalar_one_or_none()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(InitValidateAPI, "/init")
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api
|
||||
from . import api, console_ns
|
||||
|
||||
|
||||
@console_ns.route("/ping")
|
||||
class PingApi(Resource):
|
||||
@api.doc("health_check")
|
||||
@api.doc(description="Health check endpoint for connection testing")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
)
|
||||
def get(self):
|
||||
"""
|
||||
For connection health check
|
||||
"""
|
||||
"""Health check endpoint for connection testing"""
|
||||
return {"result": "pong"}
|
||||
|
||||
|
||||
api.add_resource(PingApi, "/ping")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
|
|
@ -7,23 +7,56 @@ from libs.password import valid_password
|
|||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api
|
||||
from . import api, console_ns
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
@api.doc("get_setup_status")
|
||||
@api.doc(description="Get system setup status")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
"SetupStatusResponse",
|
||||
{
|
||||
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
|
||||
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
|
||||
},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system setup status"""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
if setup_status:
|
||||
# Check if setup_status is a DifySetup object rather than a bool
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
|
||||
elif setup_status:
|
||||
return {"step": "finished"}
|
||||
return {"step": "not_started"}
|
||||
return {"step": "finished"}
|
||||
|
||||
@api.doc("setup_system")
|
||||
@api.doc(description="Initialize system setup with admin account")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"SetupRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Admin email address"),
|
||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
||||
"password": fields.String(required=True, description="Admin password"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
|
||||
@api.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Initialize system setup with admin account"""
|
||||
# is set up
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
|
@ -55,6 +88,3 @@ def get_setup_status():
|
|||
return db.session.query(DifySetup).first()
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
api.add_resource(SetupApi, "/setup")
|
||||
|
|
|
|||
|
|
@ -2,18 +2,41 @@ import json
|
|||
import logging
|
||||
|
||||
import requests
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from packaging import version
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
from . import api
|
||||
from . import api, console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@api.doc("check_version_update")
|
||||
@api.doc(description="Check for application version updates")
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
"VersionResponse",
|
||||
{
|
||||
"version": fields.String(description="Latest version number"),
|
||||
"release_date": fields.String(description="Release date of latest version"),
|
||||
"release_notes": fields.String(description="Release notes for latest version"),
|
||||
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
|
||||
"features": fields.Raw(description="Feature flags and capabilities"),
|
||||
},
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Check for application version updates"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("current_version", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
|||
except version.InvalidVersion:
|
||||
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||
return False
|
||||
|
||||
|
||||
api.add_resource(VersionApi, "/version")
|
||||
|
|
|
|||
|
|
@ -248,7 +248,9 @@ class AccountIntegrateApi(Resource):
|
|||
raise ValueError("Invalid user account")
|
||||
account = current_user
|
||||
|
||||
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
|
||||
account_integrates = db.session.scalars(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||
).all()
|
||||
|
||||
base_url = request.url_root.rstrip("/")
|
||||
oauth_base_path = "/console/api/oauth/login"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,22 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-providers")
|
||||
class AgentProviderListApi(Resource):
|
||||
@api.doc("list_agent_providers")
|
||||
@api.doc(description="Get list of available agent providers")
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
fields.List(fields.Raw(description="Agent provider information")),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -21,7 +29,16 @@ class AgentProviderListApi(Resource):
|
|||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
class AgentProviderApi(Resource):
|
||||
@api.doc("get_agent_provider")
|
||||
@api.doc(description="Get specific agent provider details")
|
||||
@api.doc(params={"provider_name": "Agent provider name"})
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
fields.Raw(description="Agent provider details"),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -30,7 +47,3 @@ class AgentProviderApi(Resource):
|
|||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
||||
|
||||
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
|
||||
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
|
@ -10,7 +10,26 @@ from libs.login import login_required
|
|||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
@api.doc(description="Create a new plugin endpoint")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"EndpointCreateRequest",
|
||||
{
|
||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
||||
"settings": fields.Raw(required=True, description="Endpoint settings"),
|
||||
"name": fields.String(required=True, description="Endpoint name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -43,7 +62,20 @@ class EndpointCreateApi(Resource):
|
|||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
class EndpointListApi(Resource):
|
||||
@api.doc("list_endpoints")
|
||||
@api.doc(description="List plugin endpoints with pagination")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -70,7 +102,23 @@ class EndpointListApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list/plugin")
|
||||
class EndpointListForSinglePluginApi(Resource):
|
||||
@api.doc("list_plugin_endpoints")
|
||||
@api.doc(description="List endpoints for a specific plugin")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@api.doc("delete_endpoint")
|
||||
@api.doc(description="Delete a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class EndpointUpdateApi(Resource):
|
||||
@api.doc("update_endpoint")
|
||||
@api.doc(description="Update a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"EndpointUpdateRequest",
|
||||
{
|
||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
||||
"settings": fields.Raw(required=True, description="Updated settings"),
|
||||
"name": fields.String(required=True, description="Updated name"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
class EndpointEnableApi(Resource):
|
||||
@api.doc("enable_endpoint")
|
||||
@api.doc(description="Enable a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -177,7 +268,19 @@ class EndpointEnableApi(Resource):
|
|||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/disable")
|
||||
class EndpointDisableApi(Resource):
|
||||
@api.doc("disable_endpoint")
|
||||
@api.doc(description="Disable a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -198,12 +301,3 @@ class EndpointDisableApi(Resource):
|
|||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
|
||||
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
|
||||
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
|
||||
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
|
||||
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
|
||||
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
|
||||
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ api = ExternalApi(
|
|||
version="1.0",
|
||||
title="Files API",
|
||||
description="API for file operations including upload and preview",
|
||||
doc="/docs", # Enable Swagger UI at /files/docs
|
||||
)
|
||||
|
||||
files_ns = Namespace("files", description="File operations", path="/")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ api = ExternalApi(
|
|||
version="1.0",
|
||||
title="Inner API",
|
||||
description="Internal APIs for enterprise features, billing, and plugin communication",
|
||||
doc="/docs", # Enable Swagger UI at /inner/api/docs
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
|
|
|
|||
|
|
@ -75,9 +75,6 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None):
|
|||
if not user_id:
|
||||
user_id = DEFAULT_SERVICE_API_USER_ID
|
||||
|
||||
del kwargs["tenant_id"]
|
||||
del kwargs["user_id"]
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ api = ExternalApi(
|
|||
version="1.0",
|
||||
title="MCP API",
|
||||
description="API for Model Context Protocol operations",
|
||||
doc="/docs", # Enable Swagger UI at /mcp/docs
|
||||
)
|
||||
|
||||
mcp_ns = Namespace("mcp", description="MCP operations", path="/")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ api = ExternalApi(
|
|||
version="1.0",
|
||||
title="Service API",
|
||||
description="API for application services",
|
||||
doc="/docs", # Enable Swagger UI at /v1/docs
|
||||
)
|
||||
|
||||
service_api_ns = Namespace("service_api", description="Service operations", path="/")
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
def put(self, app_model: App, annotation_id):
|
||||
"""Update an existing annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
|
|
@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
"""Delete an annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
|
|
|
|||
|
|
@ -559,7 +559,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_create_parser.parse_args()
|
||||
|
|
@ -583,7 +583,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
|
|
@ -610,7 +610,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
|
|
@ -634,7 +634,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_binding_parser.parse_args()
|
||||
|
|
@ -660,7 +660,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_unbinding_parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ api = ExternalApi(
|
|||
version="1.0",
|
||||
title="Web API",
|
||||
description="Public APIs for web applications including file uploads, chat interactions, and app management",
|
||||
doc="/docs", # Enable Swagger UI at /api/docs
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse
|
|||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
|
|
@ -32,15 +32,16 @@ from services.errors.audio import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web_ns.route("/audio-to-text")
|
||||
class AudioApi(WebApiResource):
|
||||
audio_to_text_response_fields = {
|
||||
"text": fields.String,
|
||||
}
|
||||
|
||||
@marshal_with(audio_to_text_response_fields)
|
||||
@api.doc("Audio to Text")
|
||||
@api.doc(description="Convert audio file to text using speech-to-text service.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Audio to Text")
|
||||
@web_ns.doc(description="Convert audio file to text using speech-to-text service.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -85,6 +86,7 @@ class AudioApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@web_ns.route("/text-to-audio")
|
||||
class TextApi(WebApiResource):
|
||||
text_to_audio_response_fields = {
|
||||
"audio_url": fields.String,
|
||||
|
|
@ -92,9 +94,9 @@ class TextApi(WebApiResource):
|
|||
}
|
||||
|
||||
@marshal_with(text_to_audio_response_fields)
|
||||
@api.doc("Text to Audio")
|
||||
@api.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Text to Audio")
|
||||
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -145,7 +147,3 @@ class TextApi(WebApiResource):
|
|||
except Exception as e:
|
||||
logger.exception("Failed to handle post request to TextApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(AudioApi, "/audio-to-text")
|
||||
api.add_resource(TextApi, "/text-to-audio")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from flask_restx import reqparse
|
|||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
|
|
@ -35,10 +35,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# define completion api for user
|
||||
@web_ns.route("/completion-messages")
|
||||
class CompletionApi(WebApiResource):
|
||||
@api.doc("Create Completion Message")
|
||||
@api.doc(description="Create a completion message for text generation applications.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Create Completion Message")
|
||||
@web_ns.doc(description="Create a completion message for text generation applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
|
||||
"query": {"description": "Query text for completion", "type": "string", "required": False},
|
||||
|
|
@ -52,7 +53,7 @@ class CompletionApi(WebApiResource):
|
|||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -106,11 +107,12 @@ class CompletionApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@web_ns.route("/completion-messages/<string:task_id>/stop")
|
||||
class CompletionStopApi(WebApiResource):
|
||||
@api.doc("Stop Completion Message")
|
||||
@api.doc(description="Stop a running completion message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
@web_ns.doc("Stop Completion Message")
|
||||
@web_ns.doc(description="Stop a running completion message task.")
|
||||
@web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@web_ns.route("/chat-messages")
|
||||
class ChatApi(WebApiResource):
|
||||
@api.doc("Create Chat Message")
|
||||
@api.doc(description="Create a chat message for conversational applications.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Create Chat Message")
|
||||
@web_ns.doc(description="Create a chat message for conversational applications.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
|
||||
"query": {"description": "User query/message", "type": "string", "required": True},
|
||||
|
|
@ -148,7 +151,7 @@ class ChatApi(WebApiResource):
|
|||
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -207,11 +210,12 @@ class ChatApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@web_ns.route("/chat-messages/<string:task_id>/stop")
|
||||
class ChatStopApi(WebApiResource):
|
||||
@api.doc("Stop Chat Message")
|
||||
@api.doc(description="Stop a running chat message task.")
|
||||
@api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@api.doc(
|
||||
@web_ns.doc("Stop Chat Message")
|
||||
@web_ns.doc(description="Stop a running chat message task.")
|
||||
@web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource):
|
|||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(CompletionApi, "/completion-messages")
|
||||
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
|
||||
api.add_resource(ChatApi, "/chat-messages")
|
||||
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from flask_restx.inputs import int_range
|
|||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers
|
|||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
|
||||
@web_ns.route("/conversations")
|
||||
class ConversationListApi(WebApiResource):
|
||||
@web_ns.doc("Get Conversation List")
|
||||
@web_ns.doc(description="Retrieve paginated list of conversations for a chat application.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False},
|
||||
"limit": {
|
||||
"description": "Number of conversations to return (1-100)",
|
||||
"type": "integer",
|
||||
"required": False,
|
||||
"default": 20,
|
||||
},
|
||||
"pinned": {
|
||||
"description": "Filter by pinned status",
|
||||
"type": "string",
|
||||
"enum": ["true", "false"],
|
||||
"required": False,
|
||||
},
|
||||
"sort_by": {
|
||||
"description": "Sort order",
|
||||
"type": "string",
|
||||
"enum": ["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
"required": False,
|
||||
"default": "-updated_at",
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource):
|
|||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>")
|
||||
class ConversationApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Conversation")
|
||||
@web_ns.doc(description="Delete a specific conversation.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
204: "Conversation deleted successfully",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Conversation Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -76,7 +127,32 @@ class ConversationApi(WebApiResource):
|
|||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
class ConversationRenameApi(WebApiResource):
|
||||
@web_ns.doc("Rename Conversation")
|
||||
@web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"name": {"description": "New conversation name", "type": "string", "required": False},
|
||||
"auto_generate": {
|
||||
"description": "Auto-generate conversation name",
|
||||
"type": "boolean",
|
||||
"required": False,
|
||||
"default": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Conversation renamed successfully",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Conversation Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource):
|
|||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/pin")
|
||||
class ConversationPinApi(WebApiResource):
|
||||
pin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Pin Conversation")
|
||||
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Conversation pinned successfully",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Conversation Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(pin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/unpin")
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
unpin_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Unpin Conversation")
|
||||
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
|
||||
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Conversation unpinned successfully",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Conversation Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(unpin_response_fields)
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource):
|
|||
WebConversationService.unpin(app_model, conversation_id, end_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="web_conversation_name")
|
||||
api.add_resource(ConversationListApi, "/conversations")
|
||||
api.add_resource(ConversationApi, "/conversations/<uuid:c_id>")
|
||||
api.add_resource(ConversationPinApi, "/conversations/<uuid:c_id>/pin")
|
||||
api.add_resource(ConversationUnPinApi, "/conversations/<uuid:c_id>/unpin")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse
|
|||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
|
|
@ -38,6 +38,7 @@ from services.message_service import MessageService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web_ns.route("/messages")
|
||||
class MessageListApi(WebApiResource):
|
||||
message_fields = {
|
||||
"id": fields.String,
|
||||
|
|
@ -62,6 +63,30 @@ class MessageListApi(WebApiResource):
|
|||
"data": fields.List(fields.Nested(message_fields)),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Message List")
|
||||
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
|
||||
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
|
||||
"limit": {
|
||||
"description": "Number of messages to return (1-100)",
|
||||
"type": "integer",
|
||||
"required": False,
|
||||
"default": 20,
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Conversation Not Found or Not a Chat App",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -84,11 +109,36 @@ class MessageListApi(WebApiResource):
|
|||
raise NotFound("First Message Not Exists.")
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(WebApiResource):
|
||||
feedback_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Create Message Feedback")
|
||||
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"rating": {
|
||||
"description": "Feedback rating",
|
||||
"type": "string",
|
||||
"enum": ["like", "dislike"],
|
||||
"required": False,
|
||||
},
|
||||
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Feedback submitted successfully",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Message Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(feedback_response_fields)
|
||||
def post(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
|
@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
|
||||
class MessageMoreLikeThisApi(WebApiResource):
|
||||
@web_ns.doc("Generate More Like This")
|
||||
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"message_id": {"description": "Message UUID", "type": "string", "required": True},
|
||||
"response_mode": {
|
||||
"description": "Response mode",
|
||||
"type": "string",
|
||||
"enum": ["blocking", "streaming"],
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request - Not a completion app or feature disabled",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Message Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
|
@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
suggested_questions_response_fields = {
|
||||
"data": fields.List(fields.String),
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Suggested Questions")
|
||||
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request - Not a chat app or feature disabled",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Message Not Found or Conversation Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(suggested_questions_response_fields)
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
|
@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
api.add_resource(MessageListApi, "/messages")
|
||||
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
|
||||
api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this")
|
||||
api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse
|
|||
from flask_restx.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
|
@ -23,6 +23,7 @@ message_fields = {
|
|||
}
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages")
|
||||
class SavedMessageListApi(WebApiResource):
|
||||
saved_message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
|
|
@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource):
|
|||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Get Saved Messages")
|
||||
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"last_id": {"description": "Last message ID for pagination", "type": "string", "required": False},
|
||||
"limit": {
|
||||
"description": "Number of messages to return (1-100)",
|
||||
"type": "integer",
|
||||
"required": False,
|
||||
"default": 20,
|
||||
},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request - Not a completion app",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "App Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
|
|
@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource):
|
|||
|
||||
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
|
||||
|
||||
@web_ns.doc("Save Message")
|
||||
@web_ns.doc(description="Save a specific message for later reference.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"message_id": {"description": "Message UUID to save", "type": "string", "required": True},
|
||||
}
|
||||
)
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Message saved successfully",
|
||||
400: "Bad Request - Not a completion app",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Message Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(post_response_fields)
|
||||
def post(self, app_model, end_user):
|
||||
if app_model.mode != "completion":
|
||||
|
|
@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
@web_ns.route("/saved-messages/<uuid:message_id>")
|
||||
class SavedMessageApi(WebApiResource):
|
||||
delete_response_fields = {
|
||||
"result": fields.String,
|
||||
}
|
||||
|
||||
@web_ns.doc("Delete Saved Message")
|
||||
@web_ns.doc(description="Remove a message from saved messages.")
|
||||
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
204: "Message removed successfully",
|
||||
400: "Bad Request - Not a completion app",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Message Not Found",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
@marshal_with(delete_response_fields)
|
||||
def delete(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
|
|
@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource):
|
|||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(SavedMessageListApi, "/saved-messages")
|
||||
api.add_resource(SavedMessageApi, "/saved-messages/<uuid:message_id>")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
|
|
@ -11,6 +11,7 @@ from models.model import Site
|
|||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@web_ns.route("/site")
|
||||
class AppSiteApi(WebApiResource):
|
||||
"""Resource for app sites."""
|
||||
|
||||
|
|
@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource):
|
|||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
}
|
||||
|
||||
@api.doc("Get App Site Info")
|
||||
@api.doc(description="Retrieve app site information and configuration.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Get App Site Info")
|
||||
@web_ns.doc(description="Retrieve app site information and configuration.")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource):
|
|||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
|
||||
api.add_resource(AppSiteApi, "/site")
|
||||
|
||||
|
||||
class AppSiteInfo:
|
||||
"""Class to store site information."""
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from flask_restx import reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
NotWorkflowAppError,
|
||||
|
|
@ -30,16 +30,17 @@ from services.errors.llm import InvokeRateLimitError
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@api.doc("Run Workflow")
|
||||
@api.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Run Workflow")
|
||||
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -85,15 +86,16 @@ class WorkflowRunApi(WebApiResource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
@web_ns.route("/workflows/tasks/<string:task_id>/stop")
|
||||
class WorkflowTaskStopApi(WebApiResource):
|
||||
@api.doc("Stop Workflow Task")
|
||||
@api.doc(description="Stop a running workflow task.")
|
||||
@api.doc(
|
||||
@web_ns.doc("Stop Workflow Task")
|
||||
@web_ns.doc(description="Stop a running workflow task.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"task_id": {"description": "Task ID to stop", "type": "string", "required": True},
|
||||
}
|
||||
)
|
||||
@api.doc(
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
400: "Bad Request",
|
||||
|
|
@ -119,7 +121,3 @@ class WorkflowTaskStopApi(WebApiResource):
|
|||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
|
|
|
|||
|
|
@ -32,11 +32,16 @@ class TokenBufferMemory:
|
|||
self.model_instance = model_instance
|
||||
|
||||
def _build_prompt_message_with_files(
|
||||
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
|
||||
self,
|
||||
message_files: Sequence[MessageFile],
|
||||
text_content: str,
|
||||
message: Message,
|
||||
app_record,
|
||||
is_user_message: bool,
|
||||
) -> PromptMessage:
|
||||
"""
|
||||
Build prompt message with files.
|
||||
:param message_files: list of MessageFile objects
|
||||
:param message_files: Sequence of MessageFile objects
|
||||
:param text_content: text content of the message
|
||||
:param message: Message object
|
||||
:param app_record: app record
|
||||
|
|
@ -128,14 +133,12 @@ class TokenBufferMemory:
|
|||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
# Process user message with files
|
||||
user_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(
|
||||
user_files = db.session.scalars(
|
||||
select(MessageFile).where(
|
||||
MessageFile.message_id == message.id,
|
||||
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if user_files:
|
||||
user_prompt_message = self._build_prompt_message_with_files(
|
||||
|
|
@ -150,11 +153,9 @@ class TokenBufferMemory:
|
|||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
# Process assistant message with files
|
||||
assistant_files = (
|
||||
db.session.query(MessageFile)
|
||||
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||
.all()
|
||||
)
|
||||
assistant_files = db.session.scalars(
|
||||
select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
|
||||
).all()
|
||||
|
||||
if assistant_files:
|
||||
assistant_prompt_message = self._build_prompt_message_with_files(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource
|
|||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
|
|
@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
workflow_nodes = db.session.scalars(
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
|
|
@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
).all()
|
||||
return workflow_nodes
|
||||
|
||||
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
|
@ -139,7 +140,7 @@ class TidbService:
|
|||
|
||||
@staticmethod
|
||||
def batch_update_tidb_serverless_cluster_status(
|
||||
tidb_serverless_list: list[TidbAuthBinding],
|
||||
tidb_serverless_list: Sequence[TidbAuthBinding],
|
||||
project_id: str,
|
||||
api_url: str,
|
||||
iam_url: str,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
|
@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController):
|
|||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
db_providers = db.session.scalars(
|
||||
select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
|
||||
)
|
||||
).all()
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
|
|
|
|||
|
|
@ -87,9 +87,7 @@ class ToolLabelManager:
|
|||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
|
|
|
|||
|
|
@ -671,9 +671,9 @@ class ToolManager:
|
|||
|
||||
# get db api providers
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
db_api_providers = db.session.scalars(
|
||||
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
|
|
@ -694,9 +694,9 @@ class ToolManager:
|
|||
|
||||
if "workflow" in filters:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = (
|
||||
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
workflow_providers = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for workflow_provider in workflow_providers:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from sqlalchemy import select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import AppDatasetJoin
|
||||
|
|
@ -13,7 +15,7 @@ def handle(sender, **kwargs):
|
|||
|
||||
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
|
||||
|
||||
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
|
||||
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
|
||||
|
||||
removed_dataset_ids: set[str] = set()
|
||||
if not app_dataset_joins:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from events.app_event import app_published_workflow_was_updated
|
||||
|
|
@ -15,7 +17,7 @@ def handle(sender, **kwargs):
|
|||
published_workflow = cast(Workflow, published_workflow)
|
||||
|
||||
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
|
||||
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
|
||||
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()
|
||||
|
||||
removed_dataset_ids: set[str] = set()
|
||||
if not app_dataset_joins:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import sqlalchemy as sa
|
|||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from models.base import Base
|
||||
|
||||
|
|
@ -187,7 +188,28 @@ class Account(UserMixin, Base):
|
|||
return TenantAccountRole.is_admin_role(self.role)
|
||||
|
||||
@property
|
||||
@deprecated("Use has_edit_permission instead.")
|
||||
def is_editor(self):
|
||||
"""Determines if the account has edit permissions in their current tenant (workspace).
|
||||
|
||||
This property checks if the current role has editing privileges, which includes:
|
||||
- `OWNER`
|
||||
- `ADMIN`
|
||||
- `EDITOR`
|
||||
|
||||
Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically.
|
||||
"""
|
||||
return self.has_edit_permission
|
||||
|
||||
@property
|
||||
def has_edit_permission(self):
|
||||
"""Determines if the account has editing permissions in their current tenant (workspace).
|
||||
|
||||
This property checks if the current role has editing privileges, which includes:
|
||||
- `OWNER`
|
||||
- `ADMIN`
|
||||
- `EDITOR`
|
||||
"""
|
||||
return TenantAccountRole.is_editing_role(self.role)
|
||||
|
||||
@property
|
||||
|
|
@ -218,10 +240,12 @@ class Tenant(Base):
|
|||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
||||
def get_accounts(self) -> list[Account]:
|
||||
return (
|
||||
db.session.query(Account)
|
||||
.where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
|
||||
.all()
|
||||
return list(
|
||||
db.session.scalars(
|
||||
select(Account).where(
|
||||
Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -208,7 +208,9 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def doc_metadata(self):
|
||||
dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
|
||||
dataset_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
|
||||
).all()
|
||||
|
||||
doc_metadata = [
|
||||
{
|
||||
|
|
@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base):
|
|||
|
||||
@property
|
||||
def dataset_bindings(self) -> list[dict[str, Any]]:
|
||||
external_knowledge_bindings = (
|
||||
db.session.query(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
|
||||
.all()
|
||||
)
|
||||
external_knowledge_bindings = db.session.scalars(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
|
||||
).all()
|
||||
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
|
||||
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
dataset_bindings: list[dict[str, Any]] = []
|
||||
for dataset in datasets:
|
||||
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
|
||||
|
|
|
|||
|
|
@ -811,7 +811,7 @@ class Conversation(Base):
|
|||
|
||||
@property
|
||||
def status_count(self):
|
||||
messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
|
||||
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
|
||||
status_counts = {
|
||||
WorkflowExecutionStatus.RUNNING: 0,
|
||||
WorkflowExecutionStatus.SUCCEEDED: 0,
|
||||
|
|
@ -1090,7 +1090,7 @@ class Message(Base):
|
|||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
|
||||
feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all()
|
||||
return feedbacks
|
||||
|
||||
@property
|
||||
|
|
@ -1145,7 +1145,7 @@ class Message(Base):
|
|||
def message_files(self) -> list[dict[str, Any]]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
|
|
|||
|
|
@ -96,11 +96,11 @@ def clean_unused_datasets_task():
|
|||
break
|
||||
|
||||
for dataset in datasets:
|
||||
dataset_query = (
|
||||
db.session.query(DatasetQuery)
|
||||
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
|
||||
.all()
|
||||
)
|
||||
dataset_query = db.session.scalars(
|
||||
select(DatasetQuery).where(
|
||||
DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id
|
||||
)
|
||||
).all()
|
||||
|
||||
if not dataset_query or len(dataset_query) == 0:
|
||||
try:
|
||||
|
|
@ -121,15 +121,13 @@ def clean_unused_datasets_task():
|
|||
if should_clean:
|
||||
# Add auto disable log if required
|
||||
if add_logs:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset.id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
for document in documents:
|
||||
dataset_auto_disable_log = DatasetAutoDisableLog(
|
||||
tenant_id=dataset.tenant_id,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
from collections import defaultdict
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
|
|
@ -31,9 +32,9 @@ def mail_clean_document_notify_task():
|
|||
|
||||
# send document clean notify mail
|
||||
try:
|
||||
dataset_auto_disable_logs = (
|
||||
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
|
||||
)
|
||||
dataset_auto_disable_logs = db.session.scalars(
|
||||
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
|
||||
).all()
|
||||
# group by tenant_id
|
||||
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
||||
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
|
|
@ -15,11 +17,9 @@ def update_tidb_serverless_status_task():
|
|||
start_at = time.perf_counter()
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
.all()
|
||||
)
|
||||
tidb_serverless_list = db.session.scalars(
|
||||
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
).all()
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
# update tidb serverless status
|
||||
|
|
@ -32,7 +32,7 @@ def update_tidb_serverless_status_task():
|
|||
click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green"))
|
||||
|
||||
|
||||
def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
|
||||
def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]):
|
||||
try:
|
||||
# batch 20
|
||||
for i in range(0, len(tidb_serverless_list), 20):
|
||||
|
|
|
|||
|
|
@ -246,6 +246,8 @@ class AccountService:
|
|||
account.name = name
|
||||
|
||||
if password:
|
||||
valid_password(password)
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
|
|
|||
|
|
@ -263,11 +263,9 @@ class AppAnnotationService:
|
|||
|
||||
db.session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
annotation_hit_histories = db.session.scalars(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
).all()
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
|
|
@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
|||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
.all()
|
||||
)
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
|
||||
import click
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
@classmethod
|
||||
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
|
||||
with flask_app.app_context():
|
||||
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
|
||||
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
|
||||
app_ids = [app.id for app in apps]
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import secrets
|
|||
import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -741,14 +742,12 @@ class DatasetService:
|
|||
}
|
||||
# get recent 30 days auto disable logs
|
||||
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||
dataset_auto_disable_logs = (
|
||||
db.session.query(DatasetAutoDisableLog)
|
||||
.where(
|
||||
dataset_auto_disable_logs = db.session.scalars(
|
||||
select(DatasetAutoDisableLog).where(
|
||||
DatasetAutoDisableLog.dataset_id == dataset_id,
|
||||
DatasetAutoDisableLog.created_at >= start_date,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if dataset_auto_disable_logs:
|
||||
return {
|
||||
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
|
||||
|
|
@ -885,69 +884,58 @@ class DocumentService:
|
|||
return document
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.enabled == True,
|
||||
Document.indexing_status == "completed",
|
||||
Document.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
.all()
|
||||
)
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.batch == batch,
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
|
|
@ -984,7 +972,7 @@ class DocumentService:
|
|||
# Check if document_ids is not empty to avoid WHERE false condition
|
||||
if not document_ids or len(document_ids) == 0:
|
||||
return
|
||||
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
|
||||
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
|
||||
file_ids = [
|
||||
document.data_source_info_dict["upload_file_id"]
|
||||
for document in documents
|
||||
|
|
@ -2424,16 +2412,14 @@ class SegmentService:
|
|||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
if action == "enable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
|
|
@ -2451,16 +2437,14 @@ class SegmentService:
|
|||
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segment_ids = []
|
||||
|
|
@ -2532,16 +2516,13 @@ class SegmentService:
|
|||
dataset: Dataset,
|
||||
) -> list[ChildChunk]:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
ChildChunk.document_id == document.id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}
|
||||
|
||||
new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
|
||||
|
|
@ -2751,19 +2732,13 @@ class DatasetCollectionBindingService:
|
|||
class DatasetPermissionService:
|
||||
@classmethod
|
||||
def get_dataset_partial_member_list(cls, dataset_id):
|
||||
user_list_query = (
|
||||
db.session.query(
|
||||
user_list_query = db.session.scalars(
|
||||
select(
|
||||
DatasetPermission.account_id,
|
||||
)
|
||||
.where(DatasetPermission.dataset_id == dataset_id)
|
||||
.all()
|
||||
)
|
||||
).where(DatasetPermission.dataset_id == dataset_id)
|
||||
).all()
|
||||
|
||||
user_list = []
|
||||
for user in user_list_query:
|
||||
user_list.append(user.account_id)
|
||||
|
||||
return user_list
|
||||
return user_list_query
|
||||
|
||||
@classmethod
|
||||
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from json import JSONDecodeError
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities.provider_configuration import ProviderConfiguration
|
||||
|
|
@ -322,16 +322,14 @@ class ModelLoadBalancingService:
|
|||
if not isinstance(configs, list):
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
|
||||
current_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.where(
|
||||
current_load_balancing_configs = db.session.scalars(
|
||||
select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# id as key, config as value
|
||||
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, RecommendedApp
|
||||
|
|
@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
|||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
.all()
|
||||
)
|
||||
recommended_apps = db.session.scalars(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
).all()
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
recommended_apps = db.session.scalars(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
).all()
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import uuid
|
|||
from typing import Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -29,35 +29,30 @@ class TagService:
|
|||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
tags = db.session.scalars(
|
||||
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
if not tags:
|
||||
return []
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding.target_id)
|
||||
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
if not tag_bindings:
|
||||
return []
|
||||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
return results
|
||||
tag_bindings = db.session.scalars(
|
||||
select(TagBinding.target_id).where(
|
||||
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
).all()
|
||||
return tag_bindings
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
tags = list(
|
||||
db.session.scalars(
|
||||
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
).all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
|
|
@ -117,7 +112,7 @@ class TagService:
|
|||
raise NotFound("Tag not found")
|
||||
db.session.delete(tag)
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
|
||||
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
|
||||
if tag_bindings:
|
||||
for tag_binding in tag_bindings:
|
||||
db.session.delete(tag_binding)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
|||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -443,9 +444,7 @@ class ApiToolManageService:
|
|||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections.abc import Mapping
|
|||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
|
@ -186,7 +186,9 @@ class WorkflowToolManageService:
|
|||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
|
|
@ -39,7 +40,7 @@ def enable_annotation_reply_task(
|
|||
db.session.close()
|
||||
return
|
||||
|
||||
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
|
||||
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
|
|
@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
|
@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||
|
||||
db.session.commit()
|
||||
if file_ids:
|
||||
files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
|
||||
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
|
|
@ -55,8 +56,8 @@ def clean_dataset_task(
|
|||
index_struct=index_struct,
|
||||
collection_binding_id=collection_binding_id,
|
||||
)
|
||||
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
|
||||
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
|
||||
|
||||
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
||||
# This ensures all invalid doc_form values are properly handled
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
|
||||
|
|
@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
|||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
db.session.delete(document)
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Literal
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||
if action == "remove":
|
||||
index_processor.clean(dataset, None, with_keywords=False)
|
||||
elif action == "add":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
|
|
@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
|
|||
)
|
||||
db.session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments:
|
||||
db.session.close()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
|
|
@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
|
|
@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
|
||||
db.session.close()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str):
|
|||
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if index_node_ids:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
|
|||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
|
|
@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
|||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
"""Integration tests for ChatMessageApi permission verification."""
|
||||
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import completion as completion_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class TestChatMessageApiPermissions:
|
||||
"""Test permission verification for ChatMessageApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Test that OWNER role can access chat-messages endpoint."""
|
||||
|
||||
"""Setup common mocks for testing."""
|
||||
# Mock app loading
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(completion_api, "current_user", mock_account)
|
||||
|
||||
mock_generate = mock.Mock(return_value={"message": "Test response"})
|
||||
monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
|
||||
|
||||
# Set user role to OWNER
|
||||
mock_account.role = role
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/chat-messages",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"inputs": {},
|
||||
"query": "Hello, world!",
|
||||
"model_config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
|
||||
},
|
||||
"response_mode": "blocking",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Integration tests for ModelConfigResource permission verification."""
|
||||
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import model_config as model_config_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
class TestModelConfigResourcePermissions:
|
||||
"""Test permission verification for ModelConfigResource endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.app_model_config_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Test that OWNER role can access model-config endpoint."""
|
||||
# Set user role to OWNER
|
||||
mock_account.role = role
|
||||
|
||||
# Mock app loading
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(model_config_api, "current_user", mock_account)
|
||||
|
||||
# Mock AccountService.load_user to prevent authentication issues
|
||||
from services.account_service import AccountService
|
||||
|
||||
mock_load_user = mock.Mock(return_value=mock_account)
|
||||
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
|
||||
|
||||
mock_validate_config = mock.Mock(
|
||||
return_value={
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
|
||||
|
||||
# Mock database operations
|
||||
mock_db_session = mock.Mock()
|
||||
mock_db_session.add = mock.Mock()
|
||||
mock_db_session.flush = mock.Mock()
|
||||
mock_db_session.commit = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
|
||||
|
||||
# Mock app_model_config_was_updated event
|
||||
mock_event = mock.Mock()
|
||||
mock_event.send = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/model-config",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7, "max_tokens": 1000},
|
||||
},
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
|
|
@ -91,6 +91,28 @@ class TestAccountService:
|
|||
assert account.password is None
|
||||
assert account.password_salt is None
|
||||
|
||||
def test_create_account_password_invalid_new_password(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test account create with invalid new password format.
|
||||
"""
|
||||
fake = Faker()
|
||||
email = fake.email()
|
||||
name = fake.name()
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
# Test with too short password (assuming minimum length validation)
|
||||
with pytest.raises(ValueError): # Password validation error
|
||||
AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language="en-US",
|
||||
password="invalid_new_password",
|
||||
)
|
||||
|
||||
def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test account creation when registration is disabled.
|
||||
|
|
@ -940,7 +962,8 @@ class TestAccountService:
|
|||
Test getting user through non-existent email.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_email = fake.email()
|
||||
domain = f"test-{fake.random_letters(10)}.com"
|
||||
non_existent_email = fake.email(domain=domain)
|
||||
found_user = AccountService.get_user_through_email(non_existent_email)
|
||||
assert found_user is None
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
from models.model import Account, Tenant
|
||||
|
|
@ -468,7 +469,7 @@ class TestModelLoadBalancingService:
|
|||
assert load_balancing_config.id is not None
|
||||
|
||||
# Verify inherit config was created in database
|
||||
inherit_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
|
||||
)
|
||||
inherit_configs = db.session.scalars(
|
||||
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
|
||||
).all()
|
||||
assert len(inherit_configs) == 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
|
@ -954,7 +955,9 @@ class TestTagService:
|
|||
from extensions.ext_database import db
|
||||
|
||||
# Verify only one binding exists
|
||||
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
|
||||
bindings = db.session.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 1
|
||||
|
||||
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -1064,7 +1067,9 @@ class TestTagService:
|
|||
# No error should be raised, and database state should remain unchanged
|
||||
from extensions.ext_database import db
|
||||
|
||||
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
|
||||
bindings = db.session.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 0
|
||||
|
||||
def test_check_target_exists_knowledge_success(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account
|
||||
|
|
@ -354,16 +355,14 @@ class TestWebConversationService:
|
|||
# Verify only one pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversations = (
|
||||
db.session.query(PinnedConversation)
|
||||
.where(
|
||||
pinned_conversations = db.session.scalars(
|
||||
select(PinnedConversation).where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
PinnedConversation.created_by_role == "account",
|
||||
PinnedConversation.created_by == account.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
assert len(pinned_conversations) == 1
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -246,6 +246,43 @@ class TestEmailI18nService:
|
|||
sent_email = mock_sender.sent_emails[0]
|
||||
assert sent_email["subject"] == "Reset Your Dify Password"
|
||||
|
||||
def test_subject_format_keyerror_fallback_path(
|
||||
self,
|
||||
mock_renderer: MockEmailRenderer,
|
||||
mock_sender: MockEmailSender,
|
||||
):
|
||||
"""Trigger subject KeyError and cover except branch."""
|
||||
# Config with subject that references an unknown key (no {application_title} to avoid second format)
|
||||
config = EmailI18nConfig(
|
||||
templates={
|
||||
EmailType.INVITE_MEMBER: {
|
||||
EmailLanguage.EN_US: EmailTemplate(
|
||||
subject="Invite: {unknown_placeholder}",
|
||||
template_path="invite_member_en.html",
|
||||
branded_template_path="branded/invite_member_en.html",
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
branding_service = MockBrandingService(enabled=False)
|
||||
service = EmailI18nService(
|
||||
config=config,
|
||||
renderer=mock_renderer,
|
||||
branding_service=branding_service,
|
||||
sender=mock_sender,
|
||||
)
|
||||
|
||||
# Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback
|
||||
service.send_email(
|
||||
email_type=EmailType.INVITE_MEMBER,
|
||||
language_code="en-US",
|
||||
to="test@example.com",
|
||||
)
|
||||
|
||||
assert len(mock_sender.sent_emails) == 1
|
||||
# Subject is left unformatted due to KeyError fallback path without application_title
|
||||
assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}"
|
||||
|
||||
def test_send_change_email_old_phase(
|
||||
self,
|
||||
email_config: EmailI18nConfig,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
from flask import Blueprint, Flask
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
||||
def _create_api_app():
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("t", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/bad-request")
|
||||
class Bad(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise BadRequest("invalid input")
|
||||
|
||||
@api.route("/unauth")
|
||||
class Unauth(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise Unauthorized("auth required")
|
||||
|
||||
@api.route("/value-error")
|
||||
class ValErr(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise ValueError("boom")
|
||||
|
||||
@api.route("/quota")
|
||||
class Quota(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise AppInvokeQuotaExceededError("quota exceeded")
|
||||
|
||||
@api.route("/general")
|
||||
class Gen(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise RuntimeError("oops")
|
||||
|
||||
# Note: We avoid altering default_mediatype to keep normal error paths
|
||||
|
||||
# Special 400 message rewrite
|
||||
@api.route("/json-empty")
|
||||
class JsonEmpty(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Force the specific message the handler rewrites
|
||||
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
||||
raise e
|
||||
|
||||
# 400 mapping payload path
|
||||
@api.route("/param-errors")
|
||||
class ParamErrors(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
e = BadRequest()
|
||||
# Coerce a mapping description to trigger param error shaping
|
||||
e.description = {"field": "is required"} # type: ignore[assignment]
|
||||
raise e
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
return app
|
||||
|
||||
|
||||
def test_external_api_error_handlers_basic_paths():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# 400
|
||||
res = client.get("/api/bad-request")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "bad_request"
|
||||
assert data["status"] == 400
|
||||
|
||||
# 401
|
||||
res = client.get("/api/unauth")
|
||||
assert res.status_code == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# 400 ValueError
|
||||
res = client.get("/api/value-error")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["code"] == "invalid_param"
|
||||
|
||||
# 500 general
|
||||
res = client.get("/api/general")
|
||||
assert res.status_code == 500
|
||||
assert res.get_json()["status"] == 500
|
||||
|
||||
|
||||
def test_external_api_json_message_and_bad_request_rewrite():
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# JSON empty special rewrite
|
||||
res = client.get("/api/json-empty")
|
||||
assert res.status_code == 400
|
||||
assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty."
|
||||
|
||||
|
||||
def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
# Force exc_info() to return (None,None,None) only during request
|
||||
import libs.external_api as ext
|
||||
|
||||
orig_exc_info = ext.sys.exc_info
|
||||
try:
|
||||
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
|
||||
|
||||
app = _create_api_app()
|
||||
client = app.test_client()
|
||||
|
||||
# Param errors mapping payload path
|
||||
res = client.get("/api/param-errors")
|
||||
assert res.status_code == 400
|
||||
data = res.get_json()
|
||||
assert data["code"] == "invalid_param"
|
||||
assert data["params"] == "field"
|
||||
|
||||
# Quota path — depending on Flask-RESTX internals it may be handled
|
||||
res = client.get("/api/quota")
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_authorization_url()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_access_token("code")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth.get_raw_user_info("token")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({}) # type: ignore[name-defined]
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from python_http_client.exceptions import UnauthorizedError
|
||||
|
||||
from libs.sendgrid import SendGridClient
|
||||
|
||||
|
||||
def _mail(to: str = "user@example.com") -> dict:
|
||||
return {"to": to, "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_success(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
# nested attribute access: client.mail.send.post
|
||||
mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
sg.send(_mail())
|
||||
|
||||
mock_client_cls.assert_called_once()
|
||||
mock_client.client.mail.send.post.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock):
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(ValueError):
|
||||
sg.send(_mail(to=""))
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {})
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(UnauthorizedError):
|
||||
sg.send(_mail())
|
||||
|
||||
|
||||
@patch("libs.sendgrid.sendgrid.SendGridAPIClient")
|
||||
def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.client.mail.send.post.side_effect = TimeoutError("timeout")
|
||||
|
||||
sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com")
|
||||
with pytest.raises(TimeoutError):
|
||||
sg.send(_mail())
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.smtp import SMTPClient
|
||||
|
||||
|
||||
def _mail() -> dict:
|
||||
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=587,
|
||||
username="user",
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=True,
|
||||
)
|
||||
client.send(_mail())
|
||||
|
||||
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
|
||||
assert mock_smtp.ehlo.call_count == 2
|
||||
mock_smtp.starttls.assert_called_once()
|
||||
mock_smtp.login.assert_called_once_with("user", "pass")
|
||||
mock_smtp.sendmail.assert_called_once()
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL")
|
||||
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
# Cover SMTP_SSL branch and TimeoutError handling
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = TimeoutError("timeout")
|
||||
mock_smtp_ssl_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=465,
|
||||
username="",
|
||||
password="",
|
||||
_from="noreply@example.com",
|
||||
use_tls=True,
|
||||
opportunistic_tls=False,
|
||||
)
|
||||
with pytest.raises(TimeoutError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = RuntimeError("oops")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
with pytest.raises(RuntimeError):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
|
||||
# Ensure we hit the specific SMTPException except branch
|
||||
import smtplib
|
||||
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.login.side_effect = smtplib.SMTPException("login-fail")
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=25,
|
||||
username="user", # non-empty to trigger login
|
||||
password="pass",
|
||||
_from="noreply@example.com",
|
||||
)
|
||||
with pytest.raises(smtplib.SMTPException):
|
||||
client.send(_mail())
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
|
@ -28,18 +28,20 @@ class TestApiKeyAuthService:
|
|||
mock_binding.provider = self.provider
|
||||
mock_binding.disabled = False
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [mock_binding]
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].tenant_id == self.tenant_id
|
||||
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
|
||||
assert mock_session.scalars.call_count == 1
|
||||
select_arg = mock_session.scalars.call_args[0][0]
|
||||
assert "data_source_api_key_auth_binding" in str(select_arg).lower()
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_empty(self, mock_session):
|
||||
"""Test get provider auth list - empty result"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
|
|
@ -48,13 +50,15 @@ class TestApiKeyAuthService:
|
|||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
def test_get_provider_auth_list_filters_disabled(self, mock_session):
|
||||
"""Test get provider auth list - filters disabled items"""
|
||||
mock_session.query.return_value.where.return_value.all.return_value = []
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
|
||||
|
||||
# Verify where conditions include disabled.is_(False)
|
||||
where_call = mock_session.query.return_value.where.call_args[0]
|
||||
assert len(where_call) == 2 # tenant_id and disabled filter conditions
|
||||
select_stmt = mock_session.scalars.call_args[0][0]
|
||||
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
|
||||
# Ensure both tenant filter and disabled filter exist
|
||||
where_strs = [str(c).lower() for c in where_clauses]
|
||||
assert any("tenant_id" in s for s in where_strs)
|
||||
assert any("disabled" in s for s in where_strs)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.db.session")
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
|
|
|
|||
|
|
@ -63,10 +63,10 @@ class TestAuthIntegration:
|
|||
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
|
||||
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
|
||||
|
||||
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
|
||||
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ const GA: FC<IGAProps> = ({
|
|||
if (IS_CE_EDITION)
|
||||
return null
|
||||
|
||||
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : ''
|
||||
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : ''
|
||||
|
||||
return (
|
||||
<>
|
||||
|
|
@ -32,7 +32,7 @@ const GA: FC<IGAProps> = ({
|
|||
strategy="beforeInteractive"
|
||||
async
|
||||
src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`}
|
||||
nonce={nonce!}
|
||||
nonce={nonce ?? undefined}
|
||||
></Script>
|
||||
<Script
|
||||
id="ga-init"
|
||||
|
|
@ -44,14 +44,14 @@ gtag('js', new Date());
|
|||
gtag('config', '${gaIdMaps[gaType]}');
|
||||
`,
|
||||
}}
|
||||
nonce={nonce!}
|
||||
nonce={nonce ?? undefined}
|
||||
>
|
||||
</Script>
|
||||
{/* Cookie banner */}
|
||||
<Script
|
||||
id="cookieyes"
|
||||
src='https://cdn-cookieyes.com/client_data/2a645945fcae53f8e025a2b1/script.js'
|
||||
nonce={nonce!}
|
||||
nonce={nonce ?? undefined}
|
||||
></Script>
|
||||
</>
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import React, { useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useChatContext } from '../chat/chat/context'
|
||||
|
||||
const hasEndThink = (children: any): boolean => {
|
||||
if (typeof children === 'string')
|
||||
|
|
@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => {
|
|||
}
|
||||
|
||||
const useThinkTimer = (children: any) => {
|
||||
const { isResponding } = useChatContext()
|
||||
const [startTime] = useState(Date.now())
|
||||
const [elapsedTime, setElapsedTime] = useState(0)
|
||||
const [isComplete, setIsComplete] = useState(false)
|
||||
|
|
@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => {
|
|||
}, [startTime, isComplete])
|
||||
|
||||
useEffect(() => {
|
||||
if (hasEndThink(children))
|
||||
if (hasEndThink(children) || !isResponding)
|
||||
setIsComplete(true)
|
||||
}, [children])
|
||||
}, [children, isResponding])
|
||||
|
||||
return { elapsedTime, isComplete }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
import { memo } from 'react'
|
||||
import { type UnsafeUnwrappedHeaders, headers } from 'next/headers'
|
||||
import Script from 'next/script'
|
||||
import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config'
|
||||
|
||||
const Zendesk = () => {
|
||||
if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY)
|
||||
return null
|
||||
|
||||
const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : ''
|
||||
|
||||
return (
|
||||
<Script
|
||||
nonce={nonce ?? undefined}
|
||||
id="ze-snippet"
|
||||
src={`https://static.zdassets.com/ekr/snippet.js?key=${ZENDESK_WIDGET_KEY}`}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(Zendesk)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
import { IS_CE_EDITION } from '@/config'
|
||||
|
||||
export type ConversationField = {
|
||||
id: string,
|
||||
value: any,
|
||||
}
|
||||
|
||||
declare global {
|
||||
// eslint-disable-next-line ts/consistent-type-definitions
|
||||
interface Window {
|
||||
zE?: (
|
||||
command: string,
|
||||
value: string,
|
||||
payload?: ConversationField[] | string | string[] | (() => any),
|
||||
callback?: () => any,
|
||||
) => void;
|
||||
}
|
||||
}
|
||||
|
||||
export const setZendeskConversationFields = (fields: ConversationField[], callback?: () => any) => {
|
||||
if (!IS_CE_EDITION && window.zE)
|
||||
window.zE('messenger:set', 'conversationFields', fields, callback)
|
||||
}
|
||||
|
|
@ -38,7 +38,7 @@ const Field: FC<Props> = ({
|
|||
<div className={cn(className, inline && 'flex w-full items-center justify-between')}>
|
||||
<div
|
||||
onClick={() => supportFold && toggleFold()}
|
||||
className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
|
||||
className={cn('flex items-center justify-between', supportFold && 'cursor-pointer')}>
|
||||
<div className='flex h-6 items-center'>
|
||||
<div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}>
|
||||
{title} {required && <span className='text-text-destructive'>*</span>}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue