mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
6b6ab5e034
|
|
@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
|
|||
|
|
@ -189,6 +189,11 @@ class PluginConfig(BaseSettings):
|
|||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,7 +57,9 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
|
|
@ -69,8 +71,9 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
@ -89,7 +92,7 @@ class BaseApiKeyListResource(Resource):
|
|||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
|
|
@ -108,7 +111,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
|
|
@ -152,7 +156,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
@ -202,7 +206,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
@ -223,7 +227,7 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
def after_request(self, resp: Response):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -17,7 +16,7 @@ from fields.annotation_fields import (
|
|||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
|
|
@ -43,7 +42,9 @@ class AnnotationReplyActionApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -70,7 +71,9 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -99,7 +102,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -125,7 +130,9 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
|
|
@ -160,7 +167,9 @@ class AnnotationApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
@ -199,7 +208,9 @@ class AnnotationApi(Resource):
|
|||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -214,7 +225,9 @@ class AnnotationApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -250,7 +263,9 @@ class AnnotationExportApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -273,7 +288,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -289,7 +306,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -311,7 +330,9 @@ class AnnotationBatchImportApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
|
@ -342,7 +363,9 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
|
|
@ -377,7 +400,9 @@ class AnnotationHitHistoryListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,3 @@
|
|||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
|
@ -13,8 +10,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -32,7 +28,8 @@ class AppImportApi(Resource):
|
|||
@cloud_edition_billing_resource_check("apps")
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -51,7 +48,7 @@ class AppImportApi(Resource):
|
|||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
|
|
@ -85,14 +82,15 @@ class AppImportConfirmApi(Resource):
|
|||
@marshal_with(app_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
|
|
@ -110,7 +108,8 @@ class AppImportCheckDependenciesApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
|
|
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
|||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -48,11 +47,11 @@ class RuleGenerateApi(Resource):
|
|||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
|
|
@ -99,11 +98,11 @@ class RuleCodeGenerateApi(Resource):
|
|||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
|
|
@ -144,11 +143,11 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
)
|
||||
|
|
@ -198,6 +197,7 @@ class InstructionGenerateApi(Resource):
|
|||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
|
|
@ -222,21 +222,21 @@ class InstructionGenerateApi(Resource):
|
|||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
|
|
@ -245,7 +245,7 @@ class InstructionGenerateApi(Resource):
|
|||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
|
|
@ -254,7 +254,7 @@ class InstructionGenerateApi(Resource):
|
|||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import json
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -15,8 +14,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
|||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
|
@ -54,16 +52,14 @@ class ModelConfigResource(Resource):
|
|||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
config=cast(dict, request.json),
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
|
@ -95,12 +91,12 @@ class ModelConfigResource(Resource):
|
|||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
|
|
@ -134,7 +130,7 @@ class ModelConfigResource(Resource):
|
|||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
|
|
@ -142,7 +138,7 @@ class ModelConfigResource(Resource):
|
|||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -9,7 +8,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, Site
|
||||
|
||||
|
||||
|
|
@ -76,9 +75,10 @@ class AppSite(Resource):
|
|||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
|
@ -131,6 +131,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
|
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
|
|
@ -41,6 +41,8 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -50,7 +52,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
|||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
|
@ -63,9 +65,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
|||
@account_initialization_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from collections.abc import Generator
|
|||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -20,7 +19,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
|||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
|
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
).all()
|
||||
|
|
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
|
|
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
|||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
|
|
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
|||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
|
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
|
|
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
|||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
|
|
@ -239,6 +244,8 @@ class DataSourceNotionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
|
|
@ -263,7 +270,7 @@ class DataSourceNotionApi(Resource):
|
|||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
|
|
@ -271,7 +278,7 @@ class DataSourceNotionApi(Resource):
|
|||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
|
|
|
|||
|
|
@ -45,6 +45,79 @@ def _validate_name(name: str) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get supported retrieval methods based on vector database type.
|
||||
|
||||
Args:
|
||||
vector_type: Vector database type, can be None
|
||||
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||
|
||||
Returns:
|
||||
Dictionary containing supported retrieval methods
|
||||
|
||||
Raises:
|
||||
ValueError: If vector_type is None or unsupported
|
||||
"""
|
||||
if vector_type is None:
|
||||
raise ValueError("Vector store type is not configured.")
|
||||
|
||||
# Define vector database types that only support semantic search
|
||||
semantic_only_types = {
|
||||
VectorType.RELYT,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.CHROMA,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
|
||||
# Define vector database types that support all retrieval methods
|
||||
full_search_types = {
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.ELASTICSEARCH_JA,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.TIDB_ON_QDRANT,
|
||||
VectorType.LINDORM,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
full_methods = {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
|
||||
if vector_type == VectorType.MILVUS:
|
||||
return semantic_methods if is_mock else full_methods
|
||||
|
||||
if vector_type in semantic_only_types:
|
||||
return semantic_methods
|
||||
elif vector_type in full_search_types:
|
||||
return full_methods
|
||||
else:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
|
|
@ -777,50 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
| VectorType.ALIBABACLOUD_MYSQL
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
|
|
@ -833,49 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
| VectorType.ALIBABACLOUD_MYSQL
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
|
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
|
|
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -79,7 +80,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
)
|
||||
|
|
@ -115,6 +116,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -148,6 +151,8 @@ class DatasetDocumentSegmentApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
|
|
@ -171,7 +176,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -204,6 +209,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -221,7 +228,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -255,6 +262,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -272,7 +281,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -287,7 +296,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -317,6 +326,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -333,7 +344,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -361,6 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -396,7 +409,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -427,6 +440,8 @@ class ChildChunkAddApi(Resource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -441,7 +456,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -453,7 +468,7 @@ class ChildChunkAddApi(Resource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -483,6 +498,8 @@ class ChildChunkAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -499,7 +516,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -530,6 +547,8 @@ class ChildChunkAddApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -546,7 +565,7 @@ class ChildChunkAddApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -580,6 +599,8 @@ class ChildChunkUpdateApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -596,7 +617,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -607,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
|
|
@ -634,6 +655,8 @@ class ChildChunkUpdateApi(Resource):
|
|||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
|
@ -650,7 +673,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
@ -661,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
|||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
|
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
|
@ -25,9 +24,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
|
|
@ -52,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
|
|
@ -131,7 +131,9 @@ class DatasourceAuth(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -145,7 +147,7 @@ class DatasourceAuth(Resource):
|
|||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
|
|
@ -160,8 +162,10 @@ class DatasourceAuth(Resource):
|
|||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
|
|
@ -174,17 +178,19 @@ class DatasourceAuthDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
|
|
@ -198,17 +204,19 @@ class DatasourceAuthUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
|
|
@ -224,10 +232,10 @@ class DatasourceAuthListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
|
|
@ -237,10 +245,10 @@ class DatasourceHardCodeAuthListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
|
|
@ -250,7 +258,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
|
|
@ -259,7 +269,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
|
|
@ -270,10 +280,12 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
|
@ -285,7 +297,9 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -293,7 +307,7 @@ class DatasourceAuthDefaultApi(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
|
|
@ -306,7 +320,9 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
|
|
@ -315,7 +331,7 @@ class DatasourceUpdateProviderNameApi(Resource):
|
|||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
|
@ -13,7 +12,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
|
|
@ -38,7 +37,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
|
@ -58,12 +57,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
|
|
@ -81,10 +80,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
|||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
|
@ -29,9 +29,7 @@ class InstalledAppsListApi(Resource):
|
|||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
|
|
@ -121,9 +119,8 @@ class InstalledAppsListApi(Resource):
|
|||
if recommended_app is None:
|
||||
raise NotFound("App not found")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||
|
||||
if app is None:
|
||||
|
|
@ -163,9 +160,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||
"""
|
||||
|
||||
def delete(self, installed_app):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
db.session.delete(installed_app)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import InstalledApp
|
||||
from models.account import Account
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
|
|
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
assert isinstance(current_user, Account)
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
|
@ -47,7 +47,7 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
@api.doc("create_api_based_extension")
|
||||
|
|
@ -68,14 +68,17 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
name=args["name"],
|
||||
api_endpoint=args["api_endpoint"],
|
||||
api_key=args["api_key"],
|
||||
|
|
@ -95,8 +98,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
|
|
@ -119,10 +124,12 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
|
|
@ -146,10 +153,12 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,7 +22,9 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
|
|
|||
|
|
@ -108,4 +108,4 @@ class FileSupportTypeApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -16,7 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
|
@ -65,7 +63,7 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_type = request.args.get("type", type=str, default="")
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
|
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
|
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
|
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -6,7 +5,7 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
|
|
@ -34,7 +33,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -51,7 +50,7 @@ class EndpointCreateApi(Resource):
|
|||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
|
|
@ -80,7 +79,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -93,7 +92,7 @@ class EndpointListApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -123,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -138,7 +137,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
|
|
@ -165,7 +164,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -177,9 +176,7 @@ class EndpointDeleteApi(Resource):
|
|||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -207,7 +204,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -224,7 +221,7 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
|
|
@ -250,7 +247,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -262,9 +259,7 @@ class EndpointEnableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -285,7 +280,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -297,7 +292,5 @@ class EndpointDisableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -26,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
|
@ -42,8 +41,7 @@ class MemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
|
|
@ -70,9 +68,7 @@ class MemberInviteEmailApi(Resource):
|
|||
interface_language = args["language"]
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
|
|
@ -121,8 +117,7 @@ class MemberCancelInviteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.query(Account).where(Account.id == str(member_id)).first()
|
||||
|
|
@ -161,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
|
|||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
|
|
@ -190,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(account_with_role_list_fields)
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
|
|
@ -213,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
|
|||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
@ -251,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
|
|||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
@ -297,8 +286,7 @@ class OwnerTransfer(Resource):
|
|||
args = parser.parse_args()
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import io
|
||||
|
||||
from flask import send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
|
@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
|
|
@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
|
|
@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
|
|
@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
|
|
@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
|
|
@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
|||
def get(self, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
if not current_user.current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
prefilled_email=current_user.email,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
|
@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"model_type",
|
||||
|
|
@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model_entity = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id, model_type=args["model_type"]
|
||||
|
|
@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args["model_settings"]
|
||||
for model_setting in model_settings:
|
||||
|
|
@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
|
@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
|
|
@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
|
|||
raise ValueError("credential_id is required when configuring a custom-model")
|
||||
service = ModelProviderService()
|
||||
service.switch_active_custom_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
|
|
@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
|
|
@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
|
|
@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
try:
|
||||
model_provider_service.update_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_model_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
|||
|
||||
service = ModelProviderService()
|
||||
service.add_model_credential_to_model_list(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=provider,
|
||||
model_type=args["model_type"],
|
||||
model=args["model"],
|
||||
|
|
@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
|
|
@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
|
|
@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {
|
||||
|
|
@ -44,7 +43,7 @@ class PluginListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=False, location="args", default=1)
|
||||
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
|
||||
|
|
@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
|
|
@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
|
|
@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
|
|
@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
|
|
@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
|
|
@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
|
|
@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
|
|
@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
|
|
@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
|
|
@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
|
|
@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
|
|
@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
|
|
@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
|
|
@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
|||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||
|
|
@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
|
|||
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||
|
|
@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
|
@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
|
|||
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||
|
||||
|
|
@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
|
|
@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
# check if the user is admin or owner
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
user_id = current_user.id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
|
|||
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
|
||||
args = req.parse_args()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
permission = args["permission"]
|
||||
|
||||
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
|
||||
|
|
@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
permission_dict = {
|
||||
|
|
@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
|||
@account_initialization_required
|
||||
def post(self):
|
||||
# exclude one single plugin
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("plugin_id", type=str, required=True, location="json")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import io
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restx import (
|
||||
Resource,
|
||||
reqparse,
|
||||
|
|
@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
|
|
@ -53,10 +52,9 @@ class ToolProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument(
|
||||
|
|
@ -78,9 +76,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
|
|
@ -96,9 +92,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
|
@ -109,11 +103,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = req.parse_args()
|
||||
|
|
@ -131,10 +124,9 @@ class ToolBuiltinProviderAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -161,13 +153,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
|
|
@ -193,7 +184,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
|
|
@ -218,13 +209,12 @@ class ToolApiProviderAddApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -258,10 +248,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -282,10 +271,9 @@ class ToolApiProviderListToolsApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -308,13 +296,12 @@ class ToolApiProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
|
|
@ -350,13 +337,12 @@ class ToolApiProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -377,10 +363,9 @@ class ToolApiProviderGetApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
|
|
@ -401,8 +386,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
|
|
@ -444,9 +428,9 @@ class ToolApiProviderPreviousTestApi(Resource):
|
|||
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
args["provider_name"] or "",
|
||||
args["tool_name"],
|
||||
args["credentials"],
|
||||
|
|
@ -462,13 +446,12 @@ class ToolWorkflowProviderCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -502,13 +485,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -545,13 +527,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
reqparser = reqparse.RequestParser()
|
||||
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
|
|
@ -571,10 +552,9 @@ class ToolWorkflowProviderGetApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
|
|
@ -606,10 +586,9 @@ class ToolWorkflowProviderListToolApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
|
||||
|
|
@ -631,10 +610,9 @@ class ToolBuiltinListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -653,8 +631,7 @@ class ToolApiListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -672,10 +649,9 @@ class ToolWorkflowListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(
|
||||
[
|
||||
|
|
@ -709,19 +685,18 @@ class ToolPluginOAuthApi(Resource):
|
|||
provider_name = tool_provider.provider_name
|
||||
|
||||
# todo check permission
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
|
|
@ -800,11 +775,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -819,13 +795,13 @@ class ToolOAuthCustomClient(Resource):
|
|||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
client_params=args.get("client_params", {}),
|
||||
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
|
||||
|
|
@ -835,20 +811,18 @@ class ToolOAuthCustomClient(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider
|
||||
)
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider
|
||||
)
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -858,9 +832,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
tenant_id=current_user.current_tenant_id, provider_name=provider
|
||||
tenant_id=current_tenant_id, provider_name=provider
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -871,7 +846,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||
|
|
@ -900,12 +875,12 @@ class ToolProviderMCPApi(Resource):
|
|||
)
|
||||
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
|
|
@ -940,8 +915,9 @@ class ToolProviderMCPApi(Resource):
|
|||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
MCPToolManageService.update_mcp_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
|
|
@ -962,7 +938,8 @@ class ToolProviderMCPApi(Resource):
|
|||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
|
|
@ -977,7 +954,7 @@ class ToolMCPAuthApi(Resource):
|
|||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
|
|
@ -1018,8 +995,8 @@ class ToolMCPDetailApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
|
|
@ -1029,8 +1006,7 @@ class ToolMCPListAllApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||
|
||||
|
|
@ -1043,7 +1019,7 @@ class ToolMCPUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
|
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
|
|
@ -29,6 +29,8 @@ def account_initialization_required(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
|
|
@ -75,7 +77,8 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -87,7 +90,8 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -128,7 +132,8 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -151,10 +156,11 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
||||
key = f"rate_limit_{current_tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -165,7 +171,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -185,14 +191,15 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -271,7 +278,8 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -281,10 +289,11 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view):
|
||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -16,6 +15,7 @@ from core.model_manager import ModelManager
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
|
||||
|
|
@ -66,6 +66,7 @@ class SegmentApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -84,7 +85,7 @@ class SegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -117,6 +118,7 @@ class SegmentApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get segments."""
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
|
|
@ -133,7 +135,7 @@ class SegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -149,7 +151,7 @@ class SegmentApi(DatasetApiResource):
|
|||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args["status"],
|
||||
keyword=args["keyword"],
|
||||
page=page,
|
||||
|
|
@ -184,6 +186,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -195,7 +198,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
|
|
@ -217,6 +220,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -232,7 +236,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -244,7 +248,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -266,6 +270,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
|
|
@ -277,7 +282,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
|||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -307,6 +312,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -319,7 +325,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -328,7 +334,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
|
|
@ -364,6 +370,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -376,7 +383,7 @@ class ChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -423,6 +430,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -435,7 +443,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -444,9 +452,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# check child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
|
@ -483,6 +489,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
# check dataset
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
|
@ -495,7 +502,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Document not found.")
|
||||
|
||||
# get segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
|
|
@ -504,9 +511,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
|||
raise NotFound("Segment not found.")
|
||||
|
||||
# get child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,11 @@ class ModelConfigConverter:
|
|||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
try:
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
|
||||
except ValueError:
|
||||
# Fall back to CHAT mode if the stored value is invalid
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
|
|
@ -45,6 +47,8 @@ class AppQueueManager:
|
|||
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def listen(self):
|
||||
"""
|
||||
|
|
@ -157,6 +161,7 @@ class AppQueueManager:
|
|||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@cachedmethod(lambda self: self._stopped_cache, lock=lambda self: self._cache_lock)
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
|
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
|
|||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
class DatasourceInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import uuid
|
|||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Final
|
||||
from typing import Final, cast
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
|
@ -199,7 +199,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
|
|||
raise ValueError("UUID cannot be None")
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
return cast(int, uuid_obj.int)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class TracingProviderEnum(StrEnum):
|
|||
OPIK = "opik"
|
||||
WEAVE = "weave"
|
||||
ALIYUN = "aliyun"
|
||||
TENCENT = "tencent"
|
||||
|
||||
|
||||
class BaseTracingConfig(BaseModel):
|
||||
|
|
@ -195,5 +196,32 @@ class AliyunConfig(BaseTracingConfig):
|
|||
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
class TencentConfig(BaseTracingConfig):
|
||||
"""
|
||||
Tencent APM tracing config
|
||||
"""
|
||||
|
||||
token: str
|
||||
endpoint: str
|
||||
service_name: str
|
||||
|
||||
@field_validator("token")
|
||||
@classmethod
|
||||
def token_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Token cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
|
||||
|
||||
@field_validator("service_name")
|
||||
@classmethod
|
||||
def service_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo):
|
|||
|
||||
class DatasetRetrievalTraceInfo(BaseTraceInfo):
|
||||
documents: Any = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ToolTraceInfo(BaseTraceInfo):
|
||||
|
|
|
|||
|
|
@ -120,6 +120,17 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
|||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||
|
||||
return {
|
||||
"config_class": TencentConfig,
|
||||
"secret_keys": ["token"],
|
||||
"other_keys": ["endpoint", "service_name"],
|
||||
"trace_instance": TencentDataTrace,
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
|
||||
|
|
@ -723,6 +734,7 @@ class TraceTask:
|
|||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
message_data=message_data.to_dict(),
|
||||
error=kwargs.get("error"),
|
||||
)
|
||||
|
||||
return dataset_retrieval_trace_info
|
||||
|
|
@ -889,6 +901,7 @@ class TraceQueueManager:
|
|||
continue
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
|
||||
task_data = TaskData(
|
||||
app_id=task.app_id,
|
||||
trace_info_type=type(trace_info).__name__,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,337 @@
|
|||
"""
|
||||
Tencent APM Trace Client - handles network operations, metrics, and API communication
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Meter
|
||||
from opentelemetry.metrics._internal.instrument import Histogram
|
||||
from opentelemetry.sdk.metrics.export import MetricReader
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
from .entities.tencent_semconv import LLM_OPERATION_DURATION
|
||||
from .entities.tencent_trace_entity import SpanData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentTraceClient:
|
||||
"""Tencent APM trace client using OpenTelemetry OTLP exporter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
token: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
metrics_export_interval_sec: int = 10,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.service_name = service_name
|
||||
self.metrics_export_interval_sec = metrics_export_interval_sec
|
||||
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
# Prepare gRPC endpoint/metadata
|
||||
grpc_endpoint, insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||
|
||||
headers = (("authorization", f"Bearer {token}"),)
|
||||
|
||||
self.exporter = OTLPSpanExporter(
|
||||
endpoint=grpc_endpoint,
|
||||
headers=headers,
|
||||
insecure=insecure,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
self.tracer_provider = TracerProvider(resource=self.resource)
|
||||
self.span_processor = BatchSpanProcessor(
|
||||
span_exporter=self.exporter,
|
||||
max_queue_size=max_queue_size,
|
||||
schedule_delay_millis=schedule_delay_sec * 1000,
|
||||
max_export_batch_size=max_export_batch_size,
|
||||
)
|
||||
self.tracer_provider.add_span_processor(self.span_processor)
|
||||
|
||||
self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
|
||||
|
||||
# Store span contexts for parent-child relationships
|
||||
self.span_contexts: dict[int, trace_api.SpanContext] = {}
|
||||
|
||||
self.meter: Meter | None = None
|
||||
self.hist_llm_duration: Histogram | None = None
|
||||
self.metric_reader: MetricReader | None = None
|
||||
|
||||
# Metrics exporter and instruments
|
||||
try:
|
||||
from opentelemetry import metrics
|
||||
from opentelemetry.sdk.metrics import Histogram, MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader
|
||||
|
||||
protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower()
|
||||
use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
|
||||
use_http_json = protocol in {"http/json", "http-json"}
|
||||
|
||||
# Set preferred temporality for histograms to DELTA
|
||||
preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}
|
||||
|
||||
def _create_metric_exporter(exporter_cls, **kwargs):
|
||||
"""Create metric exporter with preferred_temporality support"""
|
||||
try:
|
||||
return exporter_cls(**kwargs, preferred_temporality=preferred_temporality)
|
||||
except Exception:
|
||||
return exporter_cls(**kwargs)
|
||||
|
||||
metric_reader = None
|
||||
if use_http_json:
|
||||
exporter_cls = None
|
||||
for mod_path in (
|
||||
"opentelemetry.exporter.otlp.http.json.metric_exporter",
|
||||
"opentelemetry.exporter.otlp.json.metric_exporter",
|
||||
):
|
||||
try:
|
||||
mod = importlib.import_module(mod_path)
|
||||
exporter_cls = getattr(mod, "OTLPMetricExporter", None)
|
||||
if exporter_cls:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if exporter_cls is not None:
|
||||
metric_exporter = _create_metric_exporter(
|
||||
exporter_cls,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
else:
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as HttpMetricExporter,
|
||||
)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
HttpMetricExporter,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
|
||||
elif use_http_protobuf:
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import (
|
||||
OTLPMetricExporter as HttpMetricExporter,
|
||||
)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
HttpMetricExporter,
|
||||
endpoint=endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
else:
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import (
|
||||
OTLPMetricExporter as GrpcMetricExporter,
|
||||
)
|
||||
|
||||
m_grpc_endpoint, m_insecure, _, _ = self._resolve_grpc_target(endpoint)
|
||||
|
||||
metric_exporter = _create_metric_exporter(
|
||||
GrpcMetricExporter,
|
||||
endpoint=m_grpc_endpoint,
|
||||
headers={"authorization": f"Bearer {token}"},
|
||||
insecure=m_insecure,
|
||||
)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=self.metrics_export_interval_sec * 1000
|
||||
)
|
||||
|
||||
if metric_reader is not None:
|
||||
provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(provider)
|
||||
self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
|
||||
self.hist_llm_duration = self.meter.create_histogram(
|
||||
name=LLM_OPERATION_DURATION,
|
||||
unit="s",
|
||||
description="LLM operation duration (seconds)",
|
||||
)
|
||||
self.metric_reader = metric_reader
|
||||
else:
|
||||
self.meter = None
|
||||
self.hist_llm_duration = None
|
||||
self.metric_reader = None
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
|
||||
self.meter = None
|
||||
self.hist_llm_duration = None
|
||||
self.metric_reader = None
|
||||
|
||||
def add_span(self, span_data: SpanData) -> None:
|
||||
"""Create and export span using OpenTelemetry Tracer API"""
|
||||
try:
|
||||
self._create_and_export_span(span_data)
|
||||
logger.debug("[Tencent APM] Created span: %s", span_data.name)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to create span: %s", span_data.name)
|
||||
|
||||
# Metrics recording API
|
||||
def record_llm_duration(self, latency_seconds: float, attributes: dict[str, str] | None = None) -> None:
|
||||
"""Record LLM operation duration histogram in seconds."""
|
||||
try:
|
||||
if not hasattr(self, "hist_llm_duration") or self.hist_llm_duration is None:
|
||||
return
|
||||
attrs: dict[str, str] = {}
|
||||
if attributes:
|
||||
for k, v in attributes.items():
|
||||
attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v # type: ignore[assignment]
|
||||
self.hist_llm_duration.record(latency_seconds, attrs) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)
|
||||
|
||||
def _create_and_export_span(self, span_data: SpanData) -> None:
|
||||
"""Create span using OpenTelemetry Tracer API"""
|
||||
try:
|
||||
parent_context = None
|
||||
if span_data.parent_span_id and span_data.parent_span_id in self.span_contexts:
|
||||
parent_context = trace_api.set_span_in_context(
|
||||
trace_api.NonRecordingSpan(self.span_contexts[span_data.parent_span_id])
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=span_data.name,
|
||||
context=parent_context,
|
||||
kind=SpanKind.INTERNAL,
|
||||
start_time=span_data.start_time,
|
||||
)
|
||||
self.span_contexts[span_data.span_id] = span.get_span_context()
|
||||
|
||||
if span_data.attributes:
|
||||
attributes: dict[str, AttributeValue] = {}
|
||||
for key, value in span_data.attributes.items():
|
||||
if isinstance(value, (int, float, bool)):
|
||||
attributes[key] = value
|
||||
else:
|
||||
attributes[key] = str(value)
|
||||
span.set_attributes(attributes)
|
||||
|
||||
if span_data.events:
|
||||
for event in span_data.events:
|
||||
span.add_event(event.name, event.attributes, event.timestamp)
|
||||
|
||||
if span_data.status:
|
||||
span.set_status(span_data.status)
|
||||
|
||||
# Manually end span; do not use context manager to avoid double-end warnings
|
||||
span.end(end_time=span_data.end_time)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Error creating span: %s", span_data.name)
|
||||
|
||||
def api_check(self) -> bool:
|
||||
"""Check API connectivity using socket connection test for gRPC endpoints"""
|
||||
try:
|
||||
# Resolve gRPC target consistently with exporters
|
||||
_, _, host, port = self._resolve_grpc_target(self.endpoint)
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(5)
|
||||
result = sock.connect_ex((host, port))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
logger.info("[Tencent APM] Endpoint %s:%s is accessible", host, port)
|
||||
return True
|
||||
else:
|
||||
logger.warning("[Tencent APM] Endpoint %s:%s is not accessible", host, port)
|
||||
if host in ["127.0.0.1", "localhost"]:
|
||||
logger.info("[Tencent APM] Development environment detected, allowing config save")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] API check failed")
|
||||
if "127.0.0.1" in self.endpoint or "localhost" in self.endpoint:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
"""Get project console URL"""
|
||||
return "https://console.cloud.tencent.com/apm"
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the client and export remaining spans"""
|
||||
try:
|
||||
if self.span_processor:
|
||||
logger.info("[Tencent APM] Flushing remaining spans before shutdown")
|
||||
_ = self.span_processor.force_flush()
|
||||
self.span_processor.shutdown()
|
||||
|
||||
if self.tracer_provider:
|
||||
self.tracer_provider.shutdown()
|
||||
if self.metric_reader is not None:
|
||||
try:
|
||||
self.metric_reader.shutdown() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Error during client shutdown")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_grpc_target(endpoint: str, default_port: int = 4317) -> tuple[str, bool, str, int]:
|
||||
"""Normalize endpoint to gRPC target and security flag.
|
||||
|
||||
Returns:
|
||||
(grpc_endpoint, insecure, host, port)
|
||||
"""
|
||||
try:
|
||||
if endpoint.startswith(("http://", "https://")):
|
||||
parsed = urlparse(endpoint)
|
||||
host = parsed.hostname or "localhost"
|
||||
port = parsed.port or default_port
|
||||
insecure = parsed.scheme == "http"
|
||||
return f"{host}:{port}", insecure, host, port
|
||||
|
||||
host = endpoint
|
||||
port = default_port
|
||||
if ":" in endpoint:
|
||||
parts = endpoint.rsplit(":", 1)
|
||||
host = parts[0] or "localhost"
|
||||
try:
|
||||
port = int(parts[1])
|
||||
except Exception:
|
||||
port = default_port
|
||||
|
||||
insecure = ("localhost" in host) or ("127.0.0.1" in host)
|
||||
return f"{host}:{port}", insecure, host, port
|
||||
except Exception:
|
||||
host, port = "localhost", default_port
|
||||
return f"{host}:{port}", True, host, port
|
||||
|
|
@ -0,0 +1 @@
|
|||
# Tencent trace entities module
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from enum import Enum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
|
||||
GEN_AI_IS_ENTRY = "gen_ai.is_entry" # mark to count the LLM-related traces
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "gen_ai.entity.input"
|
||||
|
||||
OUTPUT_VALUE = "gen_ai.entity.output"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# GENERATION
|
||||
GEN_AI_MODEL_NAME = "gen_ai.response.model"
|
||||
|
||||
GEN_AI_PROVIDER = "gen_ai.provider.name"
|
||||
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
# Instrumentation Library
|
||||
INSTRUMENTATION_NAME = "dify-sdk"
|
||||
INSTRUMENTATION_VERSION = "0.1.0"
|
||||
INSTRUMENTATION_LANGUAGE = "python"
|
||||
|
||||
|
||||
# Metrics
|
||||
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
WORKFLOW = "WORKFLOW" # OpenLLMetry
|
||||
RETRIEVER = "RETRIEVER" # RAG
|
||||
GENERATION = "GENERATION" # Langfuse
|
||||
TOOL = "TOOL" # OpenLLMetry
|
||||
AGENT = "AGENT" # OpenLLMetry
|
||||
TASK = "TASK" # OpenLLMetry
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: int = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: int = Field(..., description="The end time of the span in nanoseconds.")
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
"""
|
||||
Tencent APM Span Builder - handles all span construction logic
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.tencent_semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_IS_ENTRY,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROVIDER,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
from core.rag.models.document import Document
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentSpanBuilder:
|
||||
"""Builder class for constructing different types of spans"""
|
||||
|
||||
@staticmethod
|
||||
def _get_time_nanoseconds(time_value: datetime | None) -> int:
|
||||
"""Convert datetime to nanoseconds for span creation."""
|
||||
return TencentTraceUtils.convert_datetime_to_nanoseconds(time_value)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_spans(
|
||||
trace_info: WorkflowTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||
) -> list[SpanData]:
|
||||
"""Build workflow-related spans"""
|
||||
spans = []
|
||||
links = links or []
|
||||
|
||||
message_span_id = None
|
||||
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
|
||||
if hasattr(trace_info, "metadata") and trace_info.metadata.get("conversation_id"):
|
||||
message_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "message")
|
||||
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
if message_span_id:
|
||||
message_span = TencentSpanBuilder._build_message_span(
|
||||
trace_info, trace_id, message_span_id, user_id, status, links
|
||||
)
|
||||
spans.append(message_span)
|
||||
|
||||
workflow_span = TencentSpanBuilder._build_workflow_span(
|
||||
trace_info, trace_id, workflow_span_id, message_span_id, user_id, status, links
|
||||
)
|
||||
spans.append(workflow_span)
|
||||
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _build_message_span(
|
||||
trace_info: WorkflowTraceInfo, trace_id: int, message_span_id: int, user_id: str, status: Status, links: list
|
||||
) -> SpanData:
|
||||
"""Build message span for chatflow"""
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_IS_ENTRY: "true",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_workflow_span(
|
||||
trace_info: WorkflowTraceInfo,
|
||||
trace_id: int,
|
||||
workflow_span_id: int,
|
||||
message_span_id: int | None,
|
||||
user_id: str,
|
||||
status: Status,
|
||||
links: list,
|
||||
) -> SpanData:
|
||||
"""Build workflow span"""
|
||||
attributes = {
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
}
|
||||
|
||||
if message_span_id is None:
|
||||
attributes[GEN_AI_IS_ENTRY] = "true"
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes=attributes,
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_llm_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build LLM span for workflow nodes."""
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name="GENERATION",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
GEN_AI_PROVIDER: process_data.get("model_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_message_span(
|
||||
trace_info: MessageTraceInfo, trace_id: int, user_id: str, links: list | None = None
|
||||
) -> SpanData:
|
||||
"""Build message span."""
|
||||
links = links or []
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message"),
|
||||
name="message",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_IS_ENTRY: "true",
|
||||
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||
OUTPUT_VALUE: str(trace_info.outputs or ""),
|
||||
},
|
||||
status=status,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build tool span."""
|
||||
status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "tool"),
|
||||
name=trace_info.tool_name,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: "",
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_parameters, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_retrieval_span(trace_info: DatasetRetrievalTraceInfo, trace_id: int, parent_span_id: int) -> SpanData:
|
||||
"""Build dataset retrieval span."""
|
||||
status = Status(StatusCode.OK)
|
||||
if getattr(trace_info, "error", None):
|
||||
status = Status(StatusCode.ERROR, trace_info.error) # type: ignore[arg-type]
|
||||
|
||||
documents_data = TencentSpanBuilder._extract_retrieval_documents(trace_info.documents)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=parent_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "retrieval"),
|
||||
name="retrieval",
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs or ""),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs or ""),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
|
||||
"""Get workflow node execution status."""
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
return Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return Status(StatusCode.UNSET)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_retrieval_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build knowledge retrieval span for workflow nodes."""
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_tool_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build tool span for workflow nodes."""
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_workflow_task_span(
|
||||
trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
"""Build generic task span for workflow nodes."""
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=TencentTraceUtils.convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
|
||||
end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=TencentSpanBuilder._get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_retrieval_documents(documents: list[Document]):
|
||||
"""Extract documents data for retrieval tracing."""
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
|
|
@ -0,0 +1,317 @@
|
|||
"""
|
||||
Tencent APM tracing implementation with separated concerns
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.client import TencentTraceClient
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TencentDataTrace(BaseTraceInstance):
|
||||
"""
|
||||
Tencent APM trace implementation with single responsibility principle.
|
||||
Acts as a coordinator that delegates specific tasks to specialized classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tencent_config: TencentConfig):
|
||||
super().__init__(tencent_config)
|
||||
self.trace_client = TencentTraceClient(
|
||||
service_name=tencent_config.service_name,
|
||||
endpoint=tencent_config.endpoint,
|
||||
token=tencent_config.token,
|
||||
metrics_export_interval_sec=5,
|
||||
)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo) -> None:
|
||||
"""Main tracing entry point - coordinates different trace types."""
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self) -> bool:
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self) -> str:
|
||||
return self.trace_client.get_project_url()
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo) -> None:
|
||||
"""Handle workflow tracing by coordinating data retrieval and span construction."""
|
||||
try:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.workflow_run_id)
|
||||
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
user_id = self._get_user_id(trace_info)
|
||||
|
||||
workflow_spans = TencentSpanBuilder.build_workflow_spans(trace_info, trace_id, str(user_id), links)
|
||||
|
||||
for span in workflow_spans:
|
||||
self.trace_client.add_span(span)
|
||||
|
||||
self._process_workflow_nodes(trace_info, trace_id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process workflow trace")
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo) -> None:
|
||||
"""Handle message tracing."""
|
||||
try:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_info.message_id)
|
||||
user_id = self._get_user_id(trace_info)
|
||||
|
||||
links = []
|
||||
if trace_info.trace_id:
|
||||
links.append(TencentTraceUtils.create_link(trace_info.trace_id))
|
||||
|
||||
message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links)
|
||||
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process message trace")
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo) -> None:
|
||||
"""Handle tool tracing."""
|
||||
try:
|
||||
parent_span_id = None
|
||||
trace_root_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
trace_root_id = trace_info.message_id
|
||||
|
||||
if parent_span_id and trace_root_id:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||
|
||||
tool_span = TencentSpanBuilder.build_tool_span(trace_info, trace_id, parent_span_id)
|
||||
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process tool trace")
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo) -> None:
|
||||
"""Handle dataset retrieval tracing."""
|
||||
try:
|
||||
parent_span_id = None
|
||||
trace_root_id = None
|
||||
|
||||
if trace_info.message_id:
|
||||
parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message")
|
||||
trace_root_id = trace_info.message_id
|
||||
|
||||
if parent_span_id and trace_root_id:
|
||||
trace_id = TencentTraceUtils.convert_to_trace_id(trace_root_id)
|
||||
|
||||
retrieval_span = TencentSpanBuilder.build_retrieval_span(trace_info, trace_id, parent_span_id)
|
||||
|
||||
self.trace_client.add_span(retrieval_span)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process dataset retrieval trace")
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo) -> None:
|
||||
"""Handle suggested question tracing"""
|
||||
try:
|
||||
logger.info("[Tencent APM] Processing suggested question trace")
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process suggested question trace")
|
||||
|
||||
def _process_workflow_nodes(self, trace_info: WorkflowTraceInfo, trace_id: int) -> None:
|
||||
"""Process workflow node executions."""
|
||||
try:
|
||||
workflow_span_id = TencentTraceUtils.convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
|
||||
node_executions = self._get_workflow_node_executions(trace_info)
|
||||
|
||||
for node_execution in node_executions:
|
||||
try:
|
||||
node_span = self._build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
if node_span:
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
self._record_llm_metrics(node_execution)
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to process workflow nodes")
|
||||
|
||||
def _build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
) -> SpanData | None:
|
||||
"""Build span for different node types"""
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
return TencentSpanBuilder.build_workflow_llm_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return TencentSpanBuilder.build_workflow_retrieval_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
return TencentSpanBuilder.build_workflow_tool_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
else:
|
||||
# Handle all other node types as generic tasks
|
||||
return TencentSpanBuilder.build_workflow_task_span(
|
||||
trace_id, workflow_span_id, trace_info, node_execution
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"[Tencent APM] Error building span for node %s: %s",
|
||||
node_execution.id,
|
||||
node_execution.node_type,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> list[WorkflowNodeExecution]:
|
||||
"""Retrieve workflow node executions from database."""
|
||||
try:
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app_stmt = select(App).where(App.id == app_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator")
|
||||
|
||||
account_stmt = select(Account).where(Account.id == app.created_by)
|
||||
service_account = session.scalar(account_stmt)
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
|
||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_maker,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
|
||||
return list(executions)
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to get workflow node executions")
|
||||
return []
|
||||
|
||||
def _get_user_id(self, trace_info: BaseTraceInfo) -> str:
|
||||
"""Get user ID from trace info."""
|
||||
try:
|
||||
tenant_id = None
|
||||
user_id = None
|
||||
|
||||
if isinstance(trace_info, (WorkflowTraceInfo, GenerateNameTraceInfo)):
|
||||
tenant_id = trace_info.tenant_id
|
||||
|
||||
if hasattr(trace_info, "metadata") and trace_info.metadata:
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
|
||||
if user_id and tenant_id:
|
||||
stmt = (
|
||||
select(Account.name)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.where(Account.id == user_id, TenantAccountJoin.tenant_id == tenant_id)
|
||||
)
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
with session_maker() as session:
|
||||
account_name = session.scalar(stmt)
|
||||
return account_name or str(user_id)
|
||||
elif user_id:
|
||||
return str(user_id)
|
||||
|
||||
return "anonymous"
|
||||
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to get user ID")
|
||||
return "unknown"
|
||||
|
||||
def _record_llm_metrics(self, node_execution: WorkflowNodeExecution) -> None:
|
||||
"""Record LLM performance metrics"""
|
||||
try:
|
||||
if not hasattr(self.trace_client, "record_llm_duration"):
|
||||
return
|
||||
|
||||
process_data = node_execution.process_data or {}
|
||||
usage = process_data.get("usage", {})
|
||||
latency_s = float(usage.get("latency", 0.0))
|
||||
|
||||
if latency_s > 0:
|
||||
attributes = {
|
||||
"provider": process_data.get("model_provider", ""),
|
||||
"model": process_data.get("model_name", ""),
|
||||
"span_kind": "GENERATION",
|
||||
}
|
||||
self.trace_client.record_llm_duration(latency_s, attributes)
|
||||
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record LLM metrics")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure proper cleanup on garbage collection."""
|
||||
try:
|
||||
if hasattr(self, "trace_client"):
|
||||
self.trace_client.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
Utility functions for Tencent APM tracing
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import random
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
|
||||
class TencentTraceUtils:
|
||||
"""Utility class for common tracing operations."""
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
@staticmethod
|
||||
def convert_to_trace_id(uuid_v4: str | None) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
return cast(int, uuid_obj.int)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
|
||||
@staticmethod
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == TencentTraceUtils.INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime_to_nanoseconds(start_time: datetime | None) -> int:
|
||||
if start_time is None:
|
||||
start_time = datetime.now()
|
||||
timestamp_in_seconds = start_time.timestamp()
|
||||
return int(timestamp_in_seconds * 1e9)
|
||||
|
||||
@staticmethod
|
||||
def create_link(trace_id_str: str) -> Link:
|
||||
try:
|
||||
trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
|
||||
except (ValueError, TypeError):
|
||||
trace_id = cast(int, uuid.uuid4().int)
|
||||
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=TencentTraceUtils.INVALID_SPAN_ID,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
)
|
||||
return Link(span_context)
|
||||
|
|
@ -2,7 +2,7 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -31,6 +31,17 @@ from core.plugin.impl.exc import (
|
|||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
|
||||
_plugin_daemon_timeout_config = cast(
|
||||
float | httpx.Timeout | None,
|
||||
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
|
||||
)
|
||||
plugin_daemon_request_timeout: httpx.Timeout | None
|
||||
if _plugin_daemon_timeout_config is None:
|
||||
plugin_daemon_request_timeout = None
|
||||
elif isinstance(_plugin_daemon_timeout_config, httpx.Timeout):
|
||||
plugin_daemon_request_timeout = _plugin_daemon_timeout_config
|
||||
else:
|
||||
plugin_daemon_request_timeout = httpx.Timeout(_plugin_daemon_timeout_config)
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
@ -58,6 +69,7 @@ class BasePluginClient:
|
|||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
request_kwargs["data"] = prepared_data
|
||||
|
|
@ -116,6 +128,7 @@ class BasePluginClient:
|
|||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
"timeout": plugin_daemon_request_timeout,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
stream_kwargs["data"] = prepared_data
|
||||
|
|
@ -255,7 +268,7 @@ class BasePluginClient:
|
|||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
logger.error("Error in stream response for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
|||
|
|
@ -1046,7 +1046,7 @@ class ProviderManager:
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
|
|
@ -56,7 +56,7 @@ class RetrievalService:
|
|||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
|
|
@ -220,7 +220,7 @@ class RetrievalService:
|
|||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -43,8 +43,7 @@ class CacheEmbedding(Embeddings):
|
|||
else:
|
||||
embedding_queue_indices.append(i)
|
||||
|
||||
# release database connection, because embedding may take a long time
|
||||
db.session.close()
|
||||
# NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
|
||||
|
||||
if embedding_queue_indices:
|
||||
embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatasourceStreamEvent(Enum):
|
||||
class DatasourceStreamEvent(StrEnum):
|
||||
"""
|
||||
Datasource Stream event
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -49,7 +50,8 @@ class UnstructuredWordExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -46,7 +47,8 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
|
||||
import pypandoc # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -40,7 +41,8 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -32,7 +33,8 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
|||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -31,7 +32,8 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
|||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -32,7 +33,8 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
from configs import dify_config
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
|
|
@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
|
|||
@abstractmethod
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
|
@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
|
|
@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
|
@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class DatasetRetrieval:
|
|||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = "keyword_search"
|
||||
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
|
|
@ -623,7 +623,7 @@ class DatasetRetrieval:
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RetrievalMethod(Enum):
|
||||
class RetrievalMethod(StrEnum):
|
||||
SEMANTIC_SEARCH = "semantic_search"
|
||||
FULL_TEXT_SEARCH = "full_text_search"
|
||||
HYBRID_SEARCH = "hybrid_search"
|
||||
|
|
|
|||
|
|
@ -76,7 +76,8 @@ class MCPToolProviderController(ToolProviderController):
|
|||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
|
||||
if not db_provider.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
|
|||
"""
|
||||
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
|
|
@ -118,10 +119,10 @@ class ModelInvocationUtils:
|
|||
model_response="",
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
answer_unit_price=Decimal(),
|
||||
answer_price_unit=Decimal(),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
total_price=Decimal(),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
|
|
@ -152,7 +153,7 @@ class ModelInvocationUtils:
|
|||
raise InvokeModelError(f"Invoke error: {e}")
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
tool_model_invoke.model_response = str(response.message.content)
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
class NodeState(StrEnum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
|||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
|
@ -24,7 +24,7 @@ from core.workflow.graph_events import (
|
|||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import logging
|
||||
import time as time_module
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -267,7 +268,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||
|
||||
# Build and execute the update statement
|
||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||
result = session.execute(stmt)
|
||||
result = cast(CursorResult, session.execute(stmt))
|
||||
rows_affected = result.rowcount
|
||||
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -64,7 +64,10 @@ def build_from_mapping(
|
|||
config: FileUploadConfig | None = None,
|
||||
strict_type_validation: bool = False,
|
||||
) -> File:
|
||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if not transfer_method_value:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
|
||||
build_functions: dict[FileTransferMethod, Callable] = {
|
||||
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||
|
|
@ -104,6 +107,8 @@ def build_from_mappings(
|
|||
) -> Sequence[File]:
|
||||
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
|
||||
# Implement batch processing to reduce database load when handling multiple files.
|
||||
# Filter out None/empty mappings to avoid errors
|
||||
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
|
||||
files = [
|
||||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
|
@ -111,7 +116,7 @@ def build_from_mappings(
|
|||
config=config,
|
||||
strict_type_validation=strict_type_validation,
|
||||
)
|
||||
for mapping in mappings
|
||||
for mapping in valid_mappings
|
||||
]
|
||||
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -13,6 +13,15 @@ from models.model import EndUser
|
|||
#: A proxy for the current user. If no user is logged in, this will be an
|
||||
#: anonymous user
|
||||
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
|
||||
|
||||
|
||||
def current_account_with_tenant():
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
return current_user, current_user.current_tenant_id
|
||||
|
||||
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
|
|
|||
|
|
@ -37,10 +37,11 @@ config.set_main_option('sqlalchemy.url', get_engine_url())
|
|||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
|
||||
def get_metadata():
|
||||
return Base.metadata
|
||||
return TypeBase.metadata
|
||||
|
||||
def include_object(object, name, type_, reflected, compare_to):
|
||||
if type_ == "foreign_key_constraint":
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ from sqlalchemy import DateTime, String, func
|
|||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(Base):
|
||||
class DataSourceOauthBinding(TypeBase):
|
||||
__tablename__ = "data_source_oauth_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
|
|
@ -19,17 +19,25 @@ class DataSourceOauthBinding(Base):
|
|||
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info = mapped_column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
source_info: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(Base):
|
||||
class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
|
|
@ -37,14 +45,22 @@ class DataSourceApiKeyAuthBinding(Base):
|
|||
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credentials = mapped_column(sa.Text, nullable=True) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
@ -52,7 +68,7 @@ class DataSourceApiKeyAuthBinding(Base):
|
|||
"tenant_id": self.tenant_id,
|
||||
"category": self.category,
|
||||
"provider": self.provider,
|
||||
"credentials": json.loads(self.credentials),
|
||||
"credentials": json.loads(self.credentials) if self.credentials else None,
|
||||
"created_at": self.created_at.timestamp(),
|
||||
"updated_at": self.updated_at.timestamp(),
|
||||
"disabled": self.disabled,
|
||||
|
|
|
|||
|
|
@ -6,41 +6,43 @@ from sqlalchemy import DateTime, String
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
|
||||
class CeleryTask(Base):
|
||||
class CeleryTask(TypeBase):
|
||||
"""Task result/status."""
|
||||
|
||||
__tablename__ = "celery_taskmeta"
|
||||
|
||||
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
task_id = mapped_column(String(155), unique=True)
|
||||
status = mapped_column(String(50), default=states.PENDING)
|
||||
result = mapped_column(sa.PickleType, nullable=True)
|
||||
date_done = mapped_column(
|
||||
id: Mapped[int] = mapped_column(
|
||||
sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True, init=False
|
||||
)
|
||||
task_id: Mapped[str] = mapped_column(String(155), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(50), default=states.PENDING)
|
||||
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
|
||||
date_done: Mapped[datetime | None] = mapped_column(
|
||||
DateTime,
|
||||
default=lambda: naive_utc_now(),
|
||||
onupdate=lambda: naive_utc_now(),
|
||||
default=naive_utc_now,
|
||||
onupdate=naive_utc_now,
|
||||
nullable=True,
|
||||
)
|
||||
traceback = mapped_column(sa.Text, nullable=True)
|
||||
name = mapped_column(String(155), nullable=True)
|
||||
args = mapped_column(sa.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(sa.LargeBinary, nullable=True)
|
||||
worker = mapped_column(String(155), nullable=True)
|
||||
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
queue = mapped_column(String(155), nullable=True)
|
||||
traceback: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
name: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
args: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
|
||||
kwargs: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
|
||||
worker: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
queue: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
|
||||
|
||||
class CeleryTaskSet(Base):
|
||||
class CeleryTaskSet(TypeBase):
|
||||
"""TaskSet result."""
|
||||
|
||||
__tablename__ = "celery_tasksetmeta"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True, init=False
|
||||
)
|
||||
taskset_id = mapped_column(String(155), unique=True)
|
||||
result = mapped_column(sa.PickleType, nullable=True)
|
||||
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)
|
||||
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
|
||||
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
|
||||
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -13,7 +14,7 @@ from core.helper import encrypter
|
|||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from models.base import Base, TypeBase
|
||||
from models.base import TypeBase
|
||||
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
|
|
@ -42,28 +43,28 @@ class ToolOAuthSystemClient(TypeBase):
|
|||
|
||||
|
||||
# tenant level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthTenantClient(Base):
|
||||
class ToolOAuthTenantClient(TypeBase):
|
||||
__tablename__ = "tool_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), init=False)
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False, init=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class BuiltinToolProvider(Base):
|
||||
class BuiltinToolProvider(TypeBase):
|
||||
"""
|
||||
This table stores the tool provider information for built-in tools for each tenant.
|
||||
"""
|
||||
|
|
@ -75,37 +76,45 @@ class BuiltinToolProvider(Base):
|
|||
)
|
||||
|
||||
# id of the tool provider
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
|
||||
String(256),
|
||||
nullable=False,
|
||||
server_default=sa.text("'API KEY 1'::character varying"),
|
||||
)
|
||||
# id of the tenant
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
# who created this tool provider
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# name of the tool provider
|
||||
provider: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
# credential type, e.g., "api-key", "oauth2"
|
||||
credential_type: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying"), default="api-key"
|
||||
)
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1)
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
if not self.encrypted_credentials:
|
||||
return {}
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
|
||||
|
||||
|
||||
class ApiToolProvider(Base):
|
||||
class ApiToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the api providers.
|
||||
"""
|
||||
|
|
@ -116,31 +125,43 @@ class ApiToolProvider(Base):
|
|||
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the api provider
|
||||
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'API KEY 1'::character varying"),
|
||||
)
|
||||
# icon
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema = mapped_column(sa.Text, nullable=False)
|
||||
schema: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# who created this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# json format tools
|
||||
tools_str = mapped_column(sa.Text, nullable=False)
|
||||
tools_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# json format credentials
|
||||
credentials_str = mapped_column(sa.Text, nullable=False)
|
||||
credentials_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# privacy policy
|
||||
privacy_policy = mapped_column(String(255), nullable=True)
|
||||
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
# custom_disclaimer
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def schema_type(self) -> "ApiProviderSchemaType":
|
||||
|
|
@ -189,7 +210,7 @@ class ToolLabelBinding(TypeBase):
|
|||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
|
||||
class WorkflowToolProvider(Base):
|
||||
class WorkflowToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the workflow providers.
|
||||
"""
|
||||
|
|
@ -201,7 +222,7 @@ class WorkflowToolProvider(Base):
|
|||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the workflow provider
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# label of the workflow provider
|
||||
|
|
@ -219,15 +240,19 @@ class WorkflowToolProvider(Base):
|
|||
# description of the provider
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# parameter configuration
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]", default="[]")
|
||||
# privacy policy
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
|
||||
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default=None)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -252,7 +277,7 @@ class WorkflowToolProvider(Base):
|
|||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
|
||||
|
||||
class MCPToolProvider(Base):
|
||||
class MCPToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the mcp providers.
|
||||
"""
|
||||
|
|
@ -265,7 +290,7 @@ class MCPToolProvider(Base):
|
|||
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the mcp provider
|
||||
name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# server identifier of the mcp provider
|
||||
|
|
@ -275,27 +300,33 @@ class MCPToolProvider(Base):
|
|||
# hash of server_url for uniqueness check
|
||||
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# icon of the mcp provider
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
icon: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# encrypted credentials
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
# authed
|
||||
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
# tools
|
||||
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"), default=30.0)
|
||||
sse_read_timeout: Mapped[float] = mapped_column(
|
||||
sa.Float, nullable=False, server_default=sa.text("300"), default=300.0
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||
# encrypted headers for MCP server requests
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
|
@ -306,9 +337,11 @@ class MCPToolProvider(Base):
|
|||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
if not self.encrypted_credentials:
|
||||
return {}
|
||||
try:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
||||
except Exception:
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
@property
|
||||
|
|
@ -321,6 +354,7 @@ class MCPToolProvider(Base):
|
|||
def provider_icon(self) -> Mapping[str, str] | str:
|
||||
from core.file import helpers as file_helpers
|
||||
|
||||
assert self.icon
|
||||
try:
|
||||
return json.loads(self.icon)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -419,7 +453,7 @@ class MCPToolProvider(Base):
|
|||
return encrypter.decrypt(self.credentials)
|
||||
|
||||
|
||||
class ToolModelInvoke(Base):
|
||||
class ToolModelInvoke(TypeBase):
|
||||
"""
|
||||
store the invoke logs from tool invoke
|
||||
"""
|
||||
|
|
@ -427,37 +461,47 @@ class ToolModelInvoke(Base):
|
|||
__tablename__ = "tool_model_invokes"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# who invoke this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# tool name
|
||||
tool_name = mapped_column(String(128), nullable=False)
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
model_parameters = mapped_column(sa.Text, nullable=False)
|
||||
model_parameters: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# prompt messages
|
||||
prompt_messages = mapped_column(sa.Text, nullable=False)
|
||||
prompt_messages: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# invoke response
|
||||
model_response = mapped_column(sa.Text, nullable=False)
|
||||
model_response: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
answer_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit: Mapped[Decimal] = mapped_column(
|
||||
sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
|
||||
)
|
||||
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
@deprecated
|
||||
class ToolConversationVariables(Base):
|
||||
class ToolConversationVariables(TypeBase):
|
||||
"""
|
||||
store the conversation variables from tool invoke
|
||||
"""
|
||||
|
|
@ -470,18 +514,26 @@ class ToolConversationVariables(Base):
|
|||
sa.Index("conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# conversation user id
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# conversation id
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# variables pool
|
||||
variables_str = mapped_column(sa.Text, nullable=False)
|
||||
variables_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
|
|
@ -519,7 +571,7 @@ class ToolFile(TypeBase):
|
|||
|
||||
|
||||
@deprecated
|
||||
class DeprecatedPublishedAppTool(Base):
|
||||
class DeprecatedPublishedAppTool(TypeBase):
|
||||
"""
|
||||
The table stores the apps published as a tool for each person.
|
||||
"""
|
||||
|
|
@ -530,26 +582,34 @@ class DeprecatedPublishedAppTool(Base):
|
|||
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# id of the app
|
||||
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who published this tool
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = mapped_column(sa.Text, nullable=False)
|
||||
llm_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# query description, query will be seem as a parameter of the tool,
|
||||
# to describe this parameter to llm, we need this field
|
||||
query_description = mapped_column(sa.Text, nullable=False)
|
||||
query_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = mapped_column(String(40), nullable=False)
|
||||
query_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# name of the tool provider
|
||||
tool_name = mapped_column(String(40), nullable=False)
|
||||
tool_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# author
|
||||
author = mapped_column(String(40), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
author: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> "I18nObject":
|
||||
|
|
|
|||
|
|
@ -4,46 +4,58 @@ import sqlalchemy as sa
|
|||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
from .engine import db
|
||||
from .model import Message
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class SavedMessage(Base):
|
||||
class SavedMessage(TypeBase):
|
||||
__tablename__ = "saved_messages"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role = mapped_column(
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
|
||||
|
||||
class PinnedConversation(Base):
|
||||
class PinnedConversation(TypeBase):
|
||||
__tablename__ = "pinned_conversations"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_by_role = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'end_user'::character varying"),
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ dependencies = [
|
|||
"celery~=5.5.2",
|
||||
"chardet~=5.1.0",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress~=1.17",
|
||||
"flask-compress>=1.17,<1.18",
|
||||
"flask-cors~=6.0.0",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.0.7",
|
||||
|
|
@ -36,7 +36,7 @@ dependencies = [
|
|||
"markdown~=3.5.1",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.7.25",
|
||||
"opik~=1.8.72",
|
||||
"opentelemetry-api==1.27.0",
|
||||
"opentelemetry-distro==0.48b0",
|
||||
"opentelemetry-exporter-otlp==1.27.0",
|
||||
|
|
|
|||
|
|
@ -7,8 +7,10 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
|||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
|
@ -181,7 +183,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
result = cast(CursorResult, session.execute(delete_stmt))
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
|
|
@ -228,7 +230,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
result = cast(CursorResult, session.execute(delete_stmt))
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
|
|
@ -285,6 +287,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(stmt)
|
||||
result = cast(CursorResult, session.execute(stmt))
|
||||
session.commit()
|
||||
return result.rowcount
|
||||
|
|
|
|||
|
|
@ -22,8 +22,10 @@ Implementation Notes:
|
|||
import logging
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
|
@ -150,7 +152,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(stmt)
|
||||
result = cast(CursorResult, session.execute(stmt))
|
||||
session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
|
@ -186,7 +188,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
result = cast(CursorResult, session.execute(delete_stmt))
|
||||
session.commit()
|
||||
|
||||
batch_deleted = result.rowcount
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ from werkzeug.exceptions import NotFound
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
|
|
@ -24,10 +23,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -63,12 +62,12 @@ class AppAnnotationService:
|
|||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
assert current_user.current_tenant_id is not None
|
||||
assert current_tenant_id is not None
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -86,13 +85,12 @@ class AppAnnotationService:
|
|||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
args["score_threshold"],
|
||||
args["embedding_provider_name"],
|
||||
args["embedding_model_name"],
|
||||
|
|
@ -101,8 +99,7 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
|
|
@ -113,17 +110,16 @@ class AppAnnotationService:
|
|||
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_tenant_id)
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -153,11 +149,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -174,11 +169,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -196,7 +190,7 @@ class AppAnnotationService:
|
|||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -205,11 +199,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -234,7 +227,7 @@ class AppAnnotationService:
|
|||
update_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
app_id,
|
||||
app_annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
|
@ -244,11 +237,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -277,17 +269,16 @@ class AppAnnotationService:
|
|||
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -317,7 +308,7 @@ class AppAnnotationService:
|
|||
for annotation, annotation_setting in annotations_to_delete:
|
||||
if annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
# Step 4: Bulk delete annotations in a single query
|
||||
|
|
@ -333,11 +324,10 @@ class AppAnnotationService:
|
|||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -354,7 +344,7 @@ class AppAnnotationService:
|
|||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
# check annotation limit
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if annotation_quota_limit.limit < len(result) + annotation_quota_limit.size:
|
||||
|
|
@ -364,21 +354,18 @@ class AppAnnotationService:
|
|||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_import_annotations_task.delay(
|
||||
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
|
||||
)
|
||||
batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id)
|
||||
except Exception as e:
|
||||
return {"error_msg": str(e)}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -445,12 +432,11 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -481,12 +467,11 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -531,11 +516,10 @@ class AppAnnotationService:
|
|||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -558,7 +542,7 @@ class AppAnnotationService:
|
|||
# if annotation reply is enabled, delete annotation index
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
annotation.id, app_id, current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
|
|
|||
|
|
@ -26,10 +26,9 @@ class ApiKeyAuthService:
|
|||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
|
@ -48,6 +47,8 @@ class ApiKeyAuthService:
|
|||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
|
|
|
|||
|
|
@ -1470,7 +1470,7 @@ class DocumentService:
|
|||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
|
|
@ -1752,7 +1752,7 @@ class DocumentService:
|
|||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
|
|
@ -2205,7 +2205,7 @@ class DocumentService:
|
|||
retrieval_model = knowledge_config.retrieval_model
|
||||
else:
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=4,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import time
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -18,6 +17,7 @@ from core.tools.entities.tool_entities import CredentialType
|
|||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
|
@ -93,6 +93,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
get credential by id
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
|
|
@ -157,6 +159,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
|
|
@ -604,6 +608,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
|
|
@ -901,6 +907,8 @@ class DatasourceProviderService:
|
|||
"""
|
||||
update datasource credentials.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
|
|
@ -95,7 +97,7 @@ class WeightModel(BaseModel):
|
|||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
|
|
@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
|
|||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
class CustomConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ class ExternalDatasetService:
|
|||
else:
|
||||
raise ValueError(f"invalid endpoint: {endpoint}")
|
||||
try:
|
||||
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
|
||||
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
|
||||
if response.status_code == 502:
|
||||
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
|
||||
if response.status_code == 404:
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class HitTestingService:
|
|||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
|
|
|
|||
|
|
@ -102,6 +102,15 @@ class OpsService:
|
|||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
|
||||
|
||||
if tracing_provider == "tencent" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
|
||||
|
||||
trace_config_data.tracing_config = new_decrypt_tracing_config
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
|
|
@ -144,7 +153,7 @@ class OpsService:
|
|||
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
||||
except Exception:
|
||||
project_url = None
|
||||
elif tracing_provider in ("langsmith", "opik"):
|
||||
elif tracing_provider in ("langsmith", "opik", "tencent"):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask_login import current_user
|
|||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
|
|
@ -164,7 +165,7 @@ class RagPipelineTransformService:
|
|||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = "keyword_search"
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class ApiToolManageService:
|
|||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -683,7 +683,7 @@ class BuiltinToolManageService:
|
|||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||
new_params: dict = {
|
||||
new_params = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ class MCPToolManageService:
|
|||
raise
|
||||
|
||||
user = mcp_provider.load_user()
|
||||
if not mcp_provider.icon:
|
||||
raise ValueError("MCP provider icon is required")
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
|
|
|
|||
|
|
@ -152,7 +152,8 @@ class ToolTransformService:
|
|||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
if not db_provider.tenant_id:
|
||||
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
|
||||
# init tool configuration
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue