mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
commit
d5a7a537e5
|
|
@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
|
|||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import flask_restx
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -7,8 +8,7 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,9 +57,9 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
|
|
@ -71,9 +71,8 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -93,7 +92,7 @@ class BaseApiKeyListResource(Resource):
|
|||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
|
|
@ -112,9 +111,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
|
|
@ -158,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
@ -208,7 +206,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
@ -229,7 +227,7 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -17,7 +16,7 @@ from fields.annotation_fields import (
|
|||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
|
|
@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
|
|
@ -160,7 +167,9 @@ class AnnotationApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
@ -199,7 +208,9 @@ class AnnotationApi(Resource):
|
|||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -214,7 +225,9 @@ class AnnotationApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -250,7 +263,9 @@ class AnnotationExportApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
|
|
@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
|
@ -13,8 +10,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -32,7 +28,8 @@ class AppImportApi(Resource):
|
|||
@cloud_edition_billing_resource_check("apps")
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -51,7 +48,7 @@ class AppImportApi(Resource):
|
|||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
|
|
@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource):
|
|||
@marshal_with(app_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
|
|
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
|||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -48,11 +47,11 @@ class RuleGenerateApi(Resource):
|
|||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
|
|
@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource):
|
|||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
|
|
@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
)
|
||||
|
|
@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource):
|
|||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
|
|
@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource):
|
|||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
|
|
@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource):
|
|||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
|
|
@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource):
|
|||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import json
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
|||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
|
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
config=cast(dict, request.json),
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
|
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
|
|||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
|
|
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
|
|||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
|
|
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
|
|||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, Site
|
||||
|
||||
|
||||
|
|
@ -76,9 +75,10 @@ class AppSite(Resource):
|
|||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
|
@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
|
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
|
|
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
|
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
|||
@account_initialization_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from collections.abc import Generator
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
|||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
|
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
|
|
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
|||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
|
|
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
|||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
|
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
|
|
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
|||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
|
|
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
|
|
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
|
|||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
|
|
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
|
|||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
|
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
|
|
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
)
|
||||
|
|
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
|
|
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
|
|
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
|
@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
|
|
@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
|
|
@ -131,7 +131,9 @@ class DatasourceAuth(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -145,7 +147,7 @@ class DatasourceAuth(Resource):
|
|||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
|
|
@ -160,8 +162,10 @@ class DatasourceAuth(Resource):
|
|||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
|
|
@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
|
|
@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
|
|
@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
|
|
@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
|
|
@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
|
|
@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
|
|
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
|
@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
|
|
@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
|
|
@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
|
@ -13,7 +12,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
|
|
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
|
|
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
|
|||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
|
|
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
|
|||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||
|
||||
if app is None:
|
||||
|
|
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||
"""
|
||||
|
||||
def delete(self, installed_app):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
db.session.delete(installed_app)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from constants import HIDDEN_VALUE
|
|||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
|
@ -47,9 +47,7 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
@api.doc("create_api_based_extension")
|
||||
|
|
@ -77,9 +75,10 @@ class APIBasedExtensionAPI(Resource):
|
|||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
name=args["name"],
|
||||
api_endpoint=args["api_endpoint"],
|
||||
api_key=args["api_key"],
|
||||
|
|
@ -102,7 +101,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
|
|
@ -128,9 +127,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
|
@ -157,9 +156,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,9 +22,9 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
|
|
|||
|
|
@ -108,4 +108,4 @@ class FileSupportTypeApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
|
@ -64,8 +63,7 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -5,18 +5,10 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
return current_user, tenant_id
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
|
|
@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -87,7 +79,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
|
@ -41,8 +41,7 @@ class MemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
|
|
@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
|
|||
interface_language = args["language"]
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
|
|
@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
|
|
@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
|
|||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
|
|
@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
|
|
@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
|
|||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
|
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
|
|
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
|
|
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
|
|
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
|
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||
def get(self, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
prefilled_email=current_user.email,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
|
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
|
|
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id, model_type=args["model_type"]
|
||||
|
|
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
for model_setting in model_settings:
|
||||
|
|
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
|
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
|
|
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
|
|||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
|
|
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
|
|
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
|
|
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
try:
|
||||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
|
|
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
|
|
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {
|
||||
|
|
@ -44,7 +43,7 @@ class PluginListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
|
|
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
|
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
|
|
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
|
|
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
|
|
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
|
|
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
|
|
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
|
|
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
|
|
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
|
|
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
|
|
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
|
|
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
|
|
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
|
|
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
|
|
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
|
|||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||
|
|
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
|
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
|
|||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
|
||||
|
|
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
|
|
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
# check if the user is admin or owner
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
|
|||
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
permission = args["permission"]
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
||||
|
|
@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
permission_dict = {
|
||||
|
|
@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
# exclude one single plugin
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_id", type=str, required=True, location="json")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import io
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import (
|
||||
Resource,
|
||||
reqparse,
|
||||
|
|
@ -26,7 +25,7 @@ from core.plugin.impl.oauth import OAuthHandler
|
|||
from core.tools.entities.tool_entities import CredentialType
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
|
@ -55,10 +54,9 @@ class ToolProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument(
|
||||
|
|
@ -80,9 +78,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
|
|
@ -98,9 +94,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
|
@ -111,11 +105,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = req.parse_args()
|
||||
|
|
@ -133,10 +126,9 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -163,13 +155,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -195,7 +186,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
|
|
@ -220,13 +211,12 @@ class ToolApiProviderAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -260,10 +250,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -284,10 +273,9 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -310,13 +298,12 @@ class ToolApiProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -352,13 +339,12 @@ class ToolApiProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -379,10 +365,9 @@ class ToolApiProviderGetApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -403,8 +388,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
|
|
@ -446,9 +430,9 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
args["provider_name"] or "",
|
||||
args["tool_name"],
|
||||
args["credentials"],
|
||||
|
|
@ -464,13 +448,12 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -504,13 +487,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -547,13 +529,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -573,10 +554,9 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
|
|
@ -608,10 +588,9 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||
|
|
@ -633,10 +612,9 @@ class ToolBuiltinListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -655,8 +633,7 @@ class ToolApiListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -674,10 +651,9 @@ class ToolWorkflowListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -711,19 +687,18 @@ class ToolPluginOAuthApi(Resource):
|
|||
provider_name = tool_provider.provider_name
|
||||
|
||||
# todo check permission
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
|
|
@ -802,11 +777,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -821,13 +797,13 @@ class ToolOAuthCustomClient(Resource):
|
|||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
client_params=args.get("client_params", {}),
|
||||
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||
|
|
@ -837,20 +813,18 @@ class ToolOAuthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider
|
||||
)
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider
|
||||
)
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -860,9 +834,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||
tenant_id=current_tenant_id, provider_name=provider
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -873,7 +848,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||
|
|
@ -900,6 +875,7 @@ class ToolProviderMCPApi(Resource):
|
|||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
parser.add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
|
||||
args = parser.parse_args()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Validate server URL
|
||||
if not is_valid_url(args["server_url"]):
|
||||
|
|
@ -913,8 +889,8 @@ class ToolProviderMCPApi(Resource):
|
|||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
|
|
@ -954,10 +930,12 @@ class ToolProviderMCPApi(Resource):
|
|||
raise ValueError("Server URL is not valid.")
|
||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
|
|
@ -982,9 +960,11 @@ class ToolProviderMCPApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
service.delete_provider(tenant_id=current_tenant_id.current_tenant_id, provider_id=args["provider_id"])
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
@ -1000,7 +980,7 @@ class ToolMCPAuthApi(Resource):
|
|||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
|
|
@ -1049,10 +1029,10 @@ class ToolMCPDetailApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=user.current_tenant_id)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
|
|
@ -1062,8 +1042,7 @@ class ToolMCPListAllApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
|
|
@ -1078,7 +1057,7 @@ class ToolMCPUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from configs import dify_config
|
|||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_user
|
||||
from models.account import Account, AccountStatus
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
|
@ -25,16 +25,13 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
assert isinstance(current_user, Account)
|
||||
return current_user
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = _current_account()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
raise AccountNotInitializedError()
|
||||
|
|
@ -80,9 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -94,10 +90,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -138,9 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -163,13 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
key = f"rate_limit_{current_tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -180,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -200,17 +191,15 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -289,9 +278,8 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -301,12 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view):
|
||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
|
|
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get segments."""
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
|
|||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args["status"],
|
||||
keyword=args["keyword"],
|
||||
page=page,
|
||||
|
|
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
|
|
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
|
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# get segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Segment not found.")
|
||||
|
||||
# get child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
|
|
@ -45,6 +47,8 @@ class AppQueueManager:
|
|||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def listen(self):
|
||||
"""
|
||||
|
|
@ -157,6 +161,7 @@ class AppQueueManager:
|
|||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
|
|||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
float | httpx.Timeout | None,
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
plugin_daemon_request_timeout = None
|
||||
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||
else:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
@ -58,6 +69,7 @@ class BasePluginClient:
|
|||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
request_kwargs["data"] = prepared_data
|
||||
|
|
@ -116,6 +128,7 @@ class BasePluginClient:
|
|||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
stream_kwargs["data"] = prepared_data
|
||||
|
|
|
|||
|
|
@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings):
|
|||
else:
|
||||
embedding_queue_indices.append(i)
|
||||
|
||||
# release database connection, because embedding may take a long time
|
||||
db.session.close()
|
||||
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
|
||||
|
||||
if embedding_queue_indices:
|
||||
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,15 @@ from models.model import EndUser
|
|||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||
|
||||
|
||||
def current_account_with_tenant():
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
return current_user, current_user.current_tenant_id
|
||||
|
||||
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
|
|
@ -24,10 +23,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -63,12 +62,12 @@ class AppAnnotationService:
|
|||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
assert current_user.current_tenant_id is not None
|
||||
assert current_tenant_id is not None
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -86,13 +85,12 @@ class AppAnnotationService:
|
|||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
args["score_threshold"],
|
||||
args["embedding_provider_name"],
|
||||
args["embedding_model_name"],
|
||||
|
|
@ -101,8 +99,7 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
|
|
@ -113,17 +110,16 @@ class AppAnnotationService:
|
|||
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -153,11 +149,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -174,11 +169,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -196,7 +190,7 @@ class AppAnnotationService:
|
|||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -205,11 +199,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -234,7 +227,7 @@ class AppAnnotationService:
|
|||
update_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
app_annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -244,11 +237,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -277,17 +269,16 @@ class AppAnnotationService:
|
|||
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -317,7 +308,7 @@ class AppAnnotationService:
|
|||
for annotation, annotation_setting in annotations_to_delete:
|
||||
if annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
# Step 4: Bulk delete annotations in a single query
|
||||
|
|
@ -333,11 +324,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -354,7 +344,7 @@ class AppAnnotationService:
|
|||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# check annotation limit
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
|
||||
|
|
@ -364,21 +354,18 @@ class AppAnnotationService:
|
|||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_import_annotations_task.delay(
|
||||
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
|
||||
)
|
||||
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {"error_msg": str(e)}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -445,12 +432,11 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -481,12 +467,11 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -531,11 +516,10 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -558,7 +542,7 @@ class AppAnnotationService:
|
|||
# if annotation reply is enabled, delete annotation index
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import time
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -18,6 +17,7 @@ from core.tools.entities.tool_entities import CredentialType
|
|||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
|
@ -93,6 +93,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
get credential by id
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
|
|
@ -157,6 +159,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
|
|
@ -604,6 +608,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
|
|
@ -901,6 +907,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
update datasource credentials.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import click
|
|||
import pandas as pd
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
|
|||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
dataset_document = db.session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
raise ValueError("Document is not available.")
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
raise ValueError("Document is not available.")
|
||||
|
||||
upload_file = session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
upload_file = db.session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
|
|
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
|
|||
)
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
|
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
|
|||
word_count_change += segment_document.word_count
|
||||
db.session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
# update document word count
|
||||
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
db.session.add(dataset_document)
|
||||
# add index to db
|
||||
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
|
|
|
|||
|
|
@ -25,9 +25,7 @@ class TestAnnotationService:
|
|||
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
|
||||
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
|
||||
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
|
||||
patch(
|
||||
"services.annotation_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
patch("services.annotation_service.current_account_with_tenant") as mock_current_account_with_tenant,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_account_feature_service.get_features.return_value.billing.enabled = False
|
||||
|
|
@ -38,6 +36,9 @@ class TestAnnotationService:
|
|||
mock_disable_task.delay.return_value = None
|
||||
mock_batch_import_task.delay.return_value = None
|
||||
|
||||
# Create mock user that will be returned by current_account_with_tenant
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
|
||||
yield {
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"feature_service": mock_feature_service,
|
||||
|
|
@ -47,7 +48,8 @@ class TestAnnotationService:
|
|||
"enable_task": mock_enable_task,
|
||||
"disable_task": mock_disable_task,
|
||||
"batch_import_task": mock_batch_import_task,
|
||||
"current_user": mock_current_user,
|
||||
"current_account_with_tenant": mock_current_account_with_tenant,
|
||||
"current_user": mock_user,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -107,6 +109,11 @@ class TestAnnotationService:
|
|||
"""
|
||||
mock_external_service_dependencies["current_user"].id = account_id
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
||||
# Configure current_account_with_tenant to return (user, tenant_id)
|
||||
mock_external_service_dependencies["current_account_with_tenant"].return_value = (
|
||||
mock_external_service_dependencies["current_user"],
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
def _create_test_conversation(self, app, account, fake):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
|
|
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_user, "tenant123")):
|
||||
with pytest.raises(AccountNotInitializedError):
|
||||
protected_view()
|
||||
|
||||
|
|
@ -163,7 +163,9 @@ class TestBillingResourceLimits:
|
|||
return "member_added"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = add_member()
|
||||
|
||||
|
|
@ -185,7 +187,10 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
add_member()
|
||||
|
|
@ -207,7 +212,10 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 1: Should reject when source is datasets
|
||||
with app.test_request_context("/?source=datasets"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_document()
|
||||
|
|
@ -215,7 +223,10 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 2: Should allow when source is not datasets
|
||||
with app.test_request_context("/?source=other"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = upload_document()
|
||||
assert result == "document_uploaded"
|
||||
|
|
@ -239,7 +250,9 @@ class TestRateLimiting:
|
|||
return "knowledge_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("test_user"), "tenant123")
|
||||
):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
@ -271,7 +284,10 @@ class TestRateLimiting:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
|
|||
|
|
@ -53,9 +53,6 @@ const ChatWrapper = () => {
|
|||
initUserVariables,
|
||||
} = useChatWithHistoryContext()
|
||||
|
||||
// Semantic variable for better code readability
|
||||
const isHistoryConversation = !!currentConversationId
|
||||
|
||||
const appConfig = useMemo(() => {
|
||||
const config = appParams || {}
|
||||
|
||||
|
|
@ -66,9 +63,9 @@ const ChatWrapper = () => {
|
|||
fileUploadConfig: (config as any).system_parameters,
|
||||
},
|
||||
supportFeedback: true,
|
||||
opening_statement: isHistoryConversation ? currentConversationItem?.introduction : (config as any).opening_statement,
|
||||
opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
|
||||
} as ChatConfig
|
||||
}, [appParams, currentConversationItem?.introduction, isHistoryConversation])
|
||||
}, [appParams, currentConversationItem?.introduction])
|
||||
const {
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
|
|
@ -79,7 +76,7 @@ const ChatWrapper = () => {
|
|||
} = useChat(
|
||||
appConfig,
|
||||
{
|
||||
inputs: (isHistoryConversation ? currentConversationInputs : newConversationInputs) as any,
|
||||
inputs: (currentConversationId ? currentConversationInputs : newConversationInputs) as any,
|
||||
inputsForm: inputsForms,
|
||||
},
|
||||
appPrevChatTree,
|
||||
|
|
@ -87,7 +84,7 @@ const ChatWrapper = () => {
|
|||
clearChatList,
|
||||
setClearChatList,
|
||||
)
|
||||
const inputsFormValue = isHistoryConversation ? currentConversationInputs : newConversationInputsRef?.current
|
||||
const inputsFormValue = currentConversationId ? currentConversationInputs : newConversationInputsRef?.current
|
||||
const inputDisabled = useMemo(() => {
|
||||
if (allInputsHidden)
|
||||
return false
|
||||
|
|
@ -136,7 +133,7 @@ const ChatWrapper = () => {
|
|||
const data: any = {
|
||||
query: message,
|
||||
files,
|
||||
inputs: formatBooleanInputs(inputsForms, isHistoryConversation ? currentConversationInputs : newConversationInputs),
|
||||
inputs: formatBooleanInputs(inputsForms, currentConversationId ? currentConversationInputs : newConversationInputs),
|
||||
conversation_id: currentConversationId,
|
||||
parent_message_id: (isRegenerate ? parentAnswer?.id : getLastAnswer(chatList)?.id) || null,
|
||||
}
|
||||
|
|
@ -146,11 +143,11 @@ const ChatWrapper = () => {
|
|||
data,
|
||||
{
|
||||
onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId),
|
||||
onConversationComplete: isHistoryConversation ? undefined : handleNewConversationCompleted,
|
||||
onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted,
|
||||
isPublicAPI: !isInstalledApp,
|
||||
},
|
||||
)
|
||||
}, [chatList, handleNewConversationCompleted, handleSend, isHistoryConversation, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
||||
}, [chatList, handleNewConversationCompleted, handleSend, currentConversationId, currentConversationInputs, newConversationInputs, isInstalledApp, appId])
|
||||
|
||||
const doRegenerate = useCallback((chatItem: ChatItemInTree, editedQuestion?: { message: string, files?: FileEntity[] }) => {
|
||||
const question = editedQuestion ? chatItem : chatList.find(item => item.id === chatItem.parentMessageId)!
|
||||
|
|
@ -163,30 +160,38 @@ const ChatWrapper = () => {
|
|||
}, [chatList, doSend])
|
||||
|
||||
const messageList = useMemo(() => {
|
||||
// Always filter out opening statement from message list as it's handled separately in welcome component
|
||||
if (currentConversationId || chatList.length > 1)
|
||||
return chatList
|
||||
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
|
||||
return chatList.filter(item => !item.isOpeningStatement)
|
||||
}, [chatList])
|
||||
|
||||
const [collapsed, setCollapsed] = useState(isHistoryConversation)
|
||||
const [collapsed, setCollapsed] = useState(!!currentConversationId)
|
||||
|
||||
const chatNode = useMemo(() => {
|
||||
if (allInputsHidden || !inputsForms.length)
|
||||
return null
|
||||
if (isMobile) {
|
||||
if (!isHistoryConversation)
|
||||
if (!currentConversationId)
|
||||
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
||||
return null
|
||||
}
|
||||
else {
|
||||
return <InputsForm collapsed={collapsed} setCollapsed={setCollapsed} />
|
||||
}
|
||||
}, [inputsForms.length, isMobile, isHistoryConversation, collapsed, allInputsHidden])
|
||||
},
|
||||
[
|
||||
inputsForms.length,
|
||||
isMobile,
|
||||
currentConversationId,
|
||||
collapsed, allInputsHidden,
|
||||
])
|
||||
|
||||
const welcome = useMemo(() => {
|
||||
const welcomeMessage = chatList.find(item => item.isOpeningStatement)
|
||||
if (respondingState)
|
||||
return null
|
||||
if (isHistoryConversation)
|
||||
if (currentConversationId)
|
||||
return null
|
||||
if (!welcomeMessage)
|
||||
return null
|
||||
|
|
@ -227,7 +232,18 @@ const ChatWrapper = () => {
|
|||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [appData?.site.icon, appData?.site.icon_background, appData?.site.icon_type, appData?.site.icon_url, chatList, collapsed, isHistoryConversation, inputsForms.length, respondingState, allInputsHidden])
|
||||
},
|
||||
[
|
||||
appData?.site.icon,
|
||||
appData?.site.icon_background,
|
||||
appData?.site.icon_type,
|
||||
appData?.site.icon_url,
|
||||
chatList, collapsed,
|
||||
currentConversationId,
|
||||
inputsForms.length,
|
||||
respondingState,
|
||||
allInputsHidden,
|
||||
])
|
||||
|
||||
const answerIcon = (appData?.site && appData.site.use_icon_as_answer_icon)
|
||||
? <AnswerIcon
|
||||
|
|
@ -251,7 +267,7 @@ const ChatWrapper = () => {
|
|||
chatFooterClassName='pb-4'
|
||||
chatFooterInnerClassName={`mx-auto w-full max-w-[768px] ${isMobile ? 'px-2' : 'px-4'}`}
|
||||
onSend={doSend}
|
||||
inputs={isHistoryConversation ? currentConversationInputs as any : newConversationInputs}
|
||||
inputs={currentConversationId ? currentConversationInputs as any : newConversationInputs}
|
||||
inputsForm={inputsForms}
|
||||
onRegenerate={doRegenerate}
|
||||
onStopResponding={handleStop}
|
||||
|
|
|
|||
|
|
@ -111,6 +111,8 @@ const Answer: FC<AnswerProps> = ({
|
|||
}
|
||||
}, [switchSibling, item.prevSibling, item.nextSibling])
|
||||
|
||||
const contentIsEmpty = content.trim() === ''
|
||||
|
||||
return (
|
||||
<div className='mb-2 flex last:mb-0'>
|
||||
<div className='relative h-10 w-10 shrink-0'>
|
||||
|
|
@ -153,14 +155,14 @@ const Answer: FC<AnswerProps> = ({
|
|||
)
|
||||
}
|
||||
{
|
||||
responding && !content && !hasAgentThoughts && (
|
||||
responding && contentIsEmpty && !hasAgentThoughts && (
|
||||
<div className='flex h-5 w-6 items-center justify-center'>
|
||||
<LoadingAnim type='text' />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
content && !hasAgentThoughts && (
|
||||
!contentIsEmpty && !hasAgentThoughts && (
|
||||
<BasicContent item={item} />
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,6 +83,15 @@ const ChatInputArea = ({
|
|||
const historyRef = useRef([''])
|
||||
const [currentIndex, setCurrentIndex] = useState(-1)
|
||||
const isComposingRef = useRef(false)
|
||||
|
||||
const handleQueryChange = useCallback(
|
||||
(value: string) => {
|
||||
setQuery(value)
|
||||
setTimeout(handleTextareaResize, 0)
|
||||
},
|
||||
[handleTextareaResize],
|
||||
)
|
||||
|
||||
const handleSend = () => {
|
||||
if (isResponding) {
|
||||
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
|
||||
|
|
@ -101,7 +110,7 @@ const ChatInputArea = ({
|
|||
}
|
||||
if (checkInputsForm(inputs, inputsForm)) {
|
||||
onSend(query, files)
|
||||
setQuery('')
|
||||
handleQueryChange('')
|
||||
setFiles([])
|
||||
}
|
||||
}
|
||||
|
|
@ -131,19 +140,19 @@ const ChatInputArea = ({
|
|||
// When the cmd + up key is pressed, output the previous element
|
||||
if (currentIndex > 0) {
|
||||
setCurrentIndex(currentIndex - 1)
|
||||
setQuery(historyRef.current[currentIndex - 1])
|
||||
handleQueryChange(historyRef.current[currentIndex - 1])
|
||||
}
|
||||
}
|
||||
else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) {
|
||||
// When the cmd + down key is pressed, output the next element
|
||||
if (currentIndex < historyRef.current.length - 1) {
|
||||
setCurrentIndex(currentIndex + 1)
|
||||
setQuery(historyRef.current[currentIndex + 1])
|
||||
handleQueryChange(historyRef.current[currentIndex + 1])
|
||||
}
|
||||
else if (currentIndex === historyRef.current.length - 1) {
|
||||
// If it is the last element, clear the input box
|
||||
setCurrentIndex(historyRef.current.length)
|
||||
setQuery('')
|
||||
handleQueryChange('')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -171,7 +180,7 @@ const ChatInputArea = ({
|
|||
<>
|
||||
<div
|
||||
className={cn(
|
||||
'relative z-10 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md',
|
||||
'relative z-10 overflow-hidden rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[9px] shadow-md',
|
||||
isDragActive && 'border border-dashed border-components-option-card-option-selected-border',
|
||||
disabled && 'pointer-events-none border-components-panel-border opacity-50 shadow-none',
|
||||
)}
|
||||
|
|
@ -197,12 +206,8 @@ const ChatInputArea = ({
|
|||
placeholder={t('common.chat.inputPlaceholder', { botName }) || ''}
|
||||
autoFocus
|
||||
minRows={1}
|
||||
onResize={handleTextareaResize}
|
||||
value={query}
|
||||
onChange={(e) => {
|
||||
setQuery(e.target.value)
|
||||
setTimeout(handleTextareaResize, 0)
|
||||
}}
|
||||
onChange={e => handleQueryChange(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
onCompositionStart={handleCompositionStart}
|
||||
onCompositionEnd={handleCompositionEnd}
|
||||
|
|
@ -221,7 +226,7 @@ const ChatInputArea = ({
|
|||
showVoiceInput && (
|
||||
<VoiceInput
|
||||
onCancel={() => setShowVoiceInput(false)}
|
||||
onConverted={text => setQuery(text)}
|
||||
onConverted={text => handleQueryChange(text)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import type { FC, Ref } from 'react'
|
||||
import { memo } from 'react'
|
||||
import {
|
||||
RiMicLine,
|
||||
|
|
@ -18,20 +19,17 @@ type OperationProps = {
|
|||
speechToTextConfig?: EnableType
|
||||
onShowVoiceInput?: () => void
|
||||
onSend: () => void
|
||||
theme?: Theme | null
|
||||
theme?: Theme | null,
|
||||
ref?: Ref<HTMLDivElement>;
|
||||
}
|
||||
const Operation = (
|
||||
{
|
||||
ref,
|
||||
fileConfig,
|
||||
speechToTextConfig,
|
||||
onShowVoiceInput,
|
||||
onSend,
|
||||
theme,
|
||||
}: OperationProps & {
|
||||
ref: React.RefObject<HTMLDivElement>;
|
||||
},
|
||||
) => {
|
||||
const Operation: FC<OperationProps> = ({
|
||||
ref,
|
||||
fileConfig,
|
||||
speechToTextConfig,
|
||||
onShowVoiceInput,
|
||||
onSend,
|
||||
theme,
|
||||
}) => {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
|
|
|
|||
|
|
@ -62,9 +62,9 @@ const ChatWrapper = () => {
|
|||
fileUploadConfig: (config as any).system_parameters,
|
||||
},
|
||||
supportFeedback: true,
|
||||
opening_statement: currentConversationId ? currentConversationItem?.introduction : (config as any).opening_statement,
|
||||
opening_statement: currentConversationItem?.introduction || (config as any).opening_statement,
|
||||
} as ChatConfig
|
||||
}, [appParams, currentConversationItem?.introduction, currentConversationId])
|
||||
}, [appParams, currentConversationItem?.introduction])
|
||||
const {
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
|
|
@ -158,8 +158,9 @@ const ChatWrapper = () => {
|
|||
}, [chatList, doSend])
|
||||
|
||||
const messageList = useMemo(() => {
|
||||
if (currentConversationId)
|
||||
if (currentConversationId || chatList.length > 1)
|
||||
return chatList
|
||||
// Without messages we are in the welcome screen, so hide the opening statement from chatlist
|
||||
return chatList.filter(item => !item.isOpeningStatement)
|
||||
}, [chatList, currentConversationId])
|
||||
|
||||
|
|
@ -240,7 +241,7 @@ const ChatWrapper = () => {
|
|||
config={appConfig}
|
||||
chatList={messageList}
|
||||
isResponding={respondingState}
|
||||
chatContainerInnerClassName={cn('mx-auto w-full max-w-full pt-4 tablet:px-4', isMobile && 'px-4')}
|
||||
chatContainerInnerClassName={cn('mx-auto w-full max-w-full px-4', messageList.length && 'pt-4')}
|
||||
chatFooterClassName={cn('pb-4', !isMobile && 'rounded-b-2xl')}
|
||||
chatFooterInnerClassName={cn('mx-auto w-full max-w-full px-4', isMobile && 'px-2')}
|
||||
onSend={doSend}
|
||||
|
|
|
|||
|
|
@ -17,12 +17,9 @@ import type {
|
|||
import { noop } from 'lodash-es'
|
||||
|
||||
export type EmbeddedChatbotContextValue = {
|
||||
userCanAccess?: boolean
|
||||
appInfoError?: any
|
||||
appInfoLoading?: boolean
|
||||
appMeta?: AppMeta
|
||||
appData?: AppData
|
||||
appParams?: ChatConfig
|
||||
appMeta: AppMeta | null
|
||||
appData: AppData | null
|
||||
appParams: ChatConfig | null
|
||||
appChatListDataLoading?: boolean
|
||||
currentConversationId: string
|
||||
currentConversationItem?: ConversationItem
|
||||
|
|
@ -59,7 +56,10 @@ export type EmbeddedChatbotContextValue = {
|
|||
}
|
||||
|
||||
export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({
|
||||
userCanAccess: false,
|
||||
appData: null,
|
||||
appMeta: null,
|
||||
appParams: null,
|
||||
appChatListDataLoading: false,
|
||||
currentConversationId: '',
|
||||
appPrevChatList: [],
|
||||
pinnedConversationList: [],
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ const Header: FC<IHeaderProps> = ({
|
|||
return (
|
||||
<div
|
||||
className={cn('flex h-14 shrink-0 items-center justify-between rounded-t-2xl px-3')}
|
||||
style={Object.assign({}, CssTransform(theme?.backgroundHeaderColorStyle ?? ''), CssTransform(theme?.headerBorderBottomStyle ?? ''))}
|
||||
style={CssTransform(theme?.headerBorderBottomStyle ?? '')}
|
||||
>
|
||||
<div className="flex grow items-center space-x-3">
|
||||
{customerIcon}
|
||||
|
|
|
|||
|
|
@ -18,9 +18,6 @@ import { CONVERSATION_ID_INFO } from '../constants'
|
|||
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, getProcessedUserVariablesFromUrlParams } from '../utils'
|
||||
import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
|
||||
import {
|
||||
fetchAppInfo,
|
||||
fetchAppMeta,
|
||||
fetchAppParams,
|
||||
fetchChatList,
|
||||
fetchConversations,
|
||||
generationConversationName,
|
||||
|
|
@ -36,8 +33,7 @@ import { InputVarType } from '@/app/components/workflow/types'
|
|||
import { TransferMethod } from '@/types/app'
|
||||
import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
|
||||
function getFormattedChatList(messages: any[]) {
|
||||
const newChatList: ChatItem[] = []
|
||||
|
|
@ -67,18 +63,10 @@ function getFormattedChatList(messages: any[]) {
|
|||
|
||||
export const useEmbeddedChatbot = () => {
|
||||
const isInstalledApp = false
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo)
|
||||
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
|
||||
appId: appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
enabled: systemFeatures.webapp_auth.enabled,
|
||||
})
|
||||
|
||||
const appData = useMemo(() => {
|
||||
return appInfo
|
||||
}, [appInfo])
|
||||
const appId = useMemo(() => appData?.app_id, [appData])
|
||||
const appInfo = useWebAppStore(s => s.appInfo)
|
||||
const appMeta = useWebAppStore(s => s.appMeta)
|
||||
const appParams = useWebAppStore(s => s.appParams)
|
||||
const appId = useMemo(() => appInfo?.app_id, [appInfo])
|
||||
|
||||
const [userId, setUserId] = useState<string>()
|
||||
const [conversationId, setConversationId] = useState<string>()
|
||||
|
|
@ -145,8 +133,6 @@ export const useEmbeddedChatbot = () => {
|
|||
return currentConversationId
|
||||
}, [currentConversationId, newConversationId])
|
||||
|
||||
const { data: appParams } = useSWR(['appParams', isInstalledApp, appId], () => fetchAppParams(isInstalledApp, appId))
|
||||
const { data: appMeta } = useSWR(['appMeta', isInstalledApp, appId], () => fetchAppMeta(isInstalledApp, appId))
|
||||
const { data: appPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100))
|
||||
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100))
|
||||
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId))
|
||||
|
|
@ -398,16 +384,13 @@ export const useEmbeddedChatbot = () => {
|
|||
}, [isInstalledApp, appId, t, notify])
|
||||
|
||||
return {
|
||||
appInfoError,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
|
||||
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
|
||||
isInstalledApp,
|
||||
allowResetChat,
|
||||
appId,
|
||||
currentConversationId,
|
||||
currentConversationItem,
|
||||
handleConversationIdInfoChange,
|
||||
appData,
|
||||
appData: appInfo,
|
||||
appParams: appParams || {} as ChatConfig,
|
||||
appMeta,
|
||||
appPinnedConversationData,
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ const Chatbot = () => {
|
|||
<div className='relative'>
|
||||
<div
|
||||
className={cn(
|
||||
'flex flex-col rounded-2xl border border-components-panel-border-subtle',
|
||||
isMobile ? 'h-[calc(100vh_-_60px)] border-[0.5px] border-components-panel-border shadow-xs' : 'h-[100vh] bg-chatbot-bg',
|
||||
'flex flex-col rounded-2xl',
|
||||
isMobile ? 'h-[calc(100vh_-_60px)] shadow-xs' : 'h-[100vh] bg-chatbot-bg',
|
||||
)}
|
||||
style={isMobile ? Object.assign({}, CssTransform(themeBuilder?.theme?.backgroundHeaderColorStyle ?? '')) : {}}
|
||||
>
|
||||
|
|
@ -62,7 +62,7 @@ const Chatbot = () => {
|
|||
theme={themeBuilder?.theme}
|
||||
onCreateNewChat={handleNewConversation}
|
||||
/>
|
||||
<div className={cn('flex grow flex-col overflow-y-auto', isMobile && '!h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}>
|
||||
<div className={cn('flex grow flex-col overflow-y-auto', isMobile && 'm-[0.5px] !h-[calc(100vh_-_3rem)] rounded-2xl bg-chatbot-bg')}>
|
||||
{appChatListDataLoading && (
|
||||
<Loading type='app' />
|
||||
)}
|
||||
|
|
@ -101,7 +101,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
|
||||
const {
|
||||
appData,
|
||||
userCanAccess,
|
||||
appParams,
|
||||
appMeta,
|
||||
appChatListDataLoading,
|
||||
|
|
@ -135,7 +134,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
} = useEmbeddedChatbot()
|
||||
|
||||
return <EmbeddedChatbotContext.Provider value={{
|
||||
userCanAccess,
|
||||
appData,
|
||||
appParams,
|
||||
appMeta,
|
||||
|
|
|
|||
Loading…
Reference in New Issue