/sync")
diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py
index 53dab3298f..ea23e097d0 100644
--- a/api/controllers/console/auth/error.py
+++ b/api/controllers/console/auth/error.py
@@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException):
- error_code = 'auth_failed'
+ error_code = "auth_failed"
description = "{message}"
code = 500
class InvalidEmailError(BaseHTTPException):
- error_code = 'invalid_email'
+ error_code = "invalid_email"
description = "The email address is not valid."
code = 400
class PasswordMismatchError(BaseHTTPException):
- error_code = 'password_mismatch'
+ error_code = "password_mismatch"
description = "The passwords do not match."
code = 400
class InvalidTokenError(BaseHTTPException):
- error_code = 'invalid_or_expired_token'
+ error_code = "invalid_or_expired_token"
description = "The token is invalid or has expired."
code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException):
- error_code = 'password_reset_rate_limit_exceeded'
+ error_code = "password_reset_rate_limit_exceeded"
description = "Password reset rate limit exceeded. Try again later."
code = 429
-
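Each error class above only carries `error_code`, `description`, and `code`; turning a raised instance into an HTTP response is the base class's job. A minimal, self-contained sketch of that pattern (the real libs.exception.BaseHTTPException is not shown in this patch, so the payload shape below is an assumption):

    import json

    class BaseHTTPException(Exception):
        # Stand-in for libs.exception.BaseHTTPException; payload shape is assumed.
        error_code = "unknown"
        description = ""
        code = 500

        def to_response(self) -> tuple[str, int]:
            return json.dumps({"code": self.error_code, "message": self.description}), self.code

    class InvalidEmailError(BaseHTTPException):
        error_code = "invalid_email"
        description = "The email address is not valid."
        code = 400

    body, status = InvalidEmailError().to_response()
    assert status == 400 and "invalid_email" in body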
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index d78be770ab..0b01a4906a 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource):
-
@setup_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('email', type=str, required=True, location='json')
+ parser.add_argument("email", type=str, required=True, location="json")
args = parser.parse_args()
- email = args['email']
+ email = args["email"]
if not email_validate(email):
raise InvalidEmailError()
@@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
-
@setup_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('token', type=str, required=True, nullable=False, location='json')
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
- token = args['token']
+ token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
- return {'is_valid': False, 'email': None}
- return {'is_valid': True, 'email': reset_data.get('email')}
+ return {"is_valid": False, "email": None}
+ return {"is_valid": True, "email": reset_data.get("email")}
class ForgotPasswordResetApi(Resource):
-
@setup_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('token', type=str, required=True, nullable=False, location='json')
- parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
- parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
+ parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
- new_password = args['new_password']
- password_confirm = args['password_confirm']
+ new_password = args["new_password"]
+ password_confirm = args["password_confirm"]
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError()
- token = args['token']
+ token = args["token"]
reset_data = AccountService.get_reset_password_data(token)
if reset_data is None:
@@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
- account = Account.query.filter_by(email=reset_data.get('email')).first()
+ account = Account.query.filter_by(email=reset_data.get("email")).first()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
- return {'result': 'success'}
+ return {"result": "success"}
-api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
-api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
-api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
+api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
+api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
+api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
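ForgotPasswordResetApi stores the new credential as a salted hash, base64-encoded (the salt itself is generated just above the hunk shown). A runnable sketch of that step, using PBKDF2-HMAC-SHA256 as a hypothetical stand-in for libs.password.hash_password:

    import base64
    import hashlib
    import os

    def hash_password(password: str, salt: bytes) -> bytes:
        # Stand-in KDF; the real hash_password may use different parameters.
        return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 10000)

    salt = os.urandom(16)
    base64_salt = base64.b64encode(salt).decode()
    base64_password_hashed = base64.b64encode(hash_password("new-secret", salt)).decode()
    # These two strings correspond to what the handler assigns to
    # account.password_salt and account.password.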
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index c135ece67e..62837af2b9 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -20,37 +20,39 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
- parser.add_argument('email', type=email, required=True, location='json')
- parser.add_argument('password', type=valid_password, required=True, location='json')
- parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
+ parser.add_argument("email", type=email, required=True, location="json")
+ parser.add_argument("password", type=valid_password, required=True, location="json")
+ parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
# todo: Verify the recaptcha
try:
- account = AccountService.authenticate(args['email'], args['password'])
+ account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e:
- return {'code': 'unauthorized', 'message': str(e)}, 401
+ return {"code": "unauthorized", "message": str(e)}, 401
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
- return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
+ return {
+ "result": "fail",
+ "data": "workspace not found, please contact system admin to invite you to join in a workspace",
+ }
token = AccountService.login(account, ip_address=get_remote_ip(request))
- return {'result': 'success', 'data': token}
+ return {"result": "success", "data": token}
class LogoutApi(Resource):
-
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
- token = request.headers.get('Authorization', '').split(' ')[1]
+ token = request.headers.get("Authorization", "").split(" ")[1]
AccountService.logout(account=account, token=token)
flask_login.logout_user()
- return {'result': 'success'}
+ return {"result": "success"}
class ResetPasswordApi(Resource):
@@ -80,11 +82,11 @@ class ResetPasswordApi(Resource):
# 'subject': 'Reset your Dify password',
# 'html': """
# Dear User,
- # The Dify team has generated a new password for you, details as follows:
+ # The Dify team has generated a new password for you, details as follows:
# {new_password}
# Please change your password to log in as soon as possible.
# Regards,
- # The Dify Team
+ # The Dify Team
# """
# }
@@ -101,8 +103,8 @@ class ResetPasswordApi(Resource):
# # handle error
# pass
- return {'result': 'success'}
+ return {"result": "success"}
-api.add_resource(LoginApi, '/login')
-api.add_resource(LogoutApi, '/logout')
+api.add_resource(LoginApi, "/login")
+api.add_resource(LogoutApi, "/logout")
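LogoutApi extracts the token with request.headers.get("Authorization", "").split(" ")[1], which assumes a well-formed "Bearer <token>" value and raises IndexError on a missing or malformed header. A defensive variant, purely illustrative and not part of this patch:

    def extract_bearer_token(authorization: str) -> str | None:
        # "Bearer <token>" -> "<token>"; returns None on bad input instead of raising.
        parts = authorization.split(" ")
        if len(parts) == 2 and parts[0] == "Bearer":
            return parts[1]
        return None

    assert extract_bearer_token("Bearer abc123") == "abc123"
    assert extract_bearer_token("") is None  # .split(" ")[1] would raise IndexError here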
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 4a651bfe7b..ae1b49f3ec 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -25,7 +25,7 @@ def get_oauth_providers():
github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET,
- redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
+ redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
)
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None
@@ -33,10 +33,10 @@ def get_oauth_providers():
google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET,
- redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
+ redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
)
- OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
+ OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth}
return OAUTH_PROVIDERS
@@ -47,7 +47,7 @@ class OAuthLogin(Resource):
oauth_provider = OAUTH_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
- return {'error': 'Invalid provider'}, 400
+ return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
@@ -59,20 +59,20 @@ class OAuthCallback(Resource):
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
- return {'error': 'Invalid provider'}, 400
+ return {"error": "Invalid provider"}, 400
- code = request.args.get('code')
+ code = request.args.get("code")
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
- logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
- return {'error': 'OAuth process failed'}, 400
+ logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
+ return {"error": "OAuth process failed"}, 400
account = _generate_account(provider, user_info)
# Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
- return {'error': 'Account is banned or closed.'}, 403
+ return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
@@ -83,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request))
- return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
+ return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
- account_name = user_info.name if user_info.name else 'Dify'
+ account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)
@@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account
-api.add_resource(OAuthLogin, '/oauth/login/<provider>')
-api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
+api.add_resource(OAuthLogin, "/oauth/login/<provider>")
+api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
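Both endpoints resolve the provider named in the URL through the same name-to-client registry, where an unconfigured provider is stored as None and rejected with a 400. A compact sketch of that pattern (toy client; the GitHubOAuth/GoogleOAuth interface is inferred from its usage above):

    class ToyOAuth:
        def __init__(self, name: str):
            self.name = name

        def get_authorization_url(self) -> str:
            return f"https://example.com/{self.name}/oauth/authorize"

    def get_oauth_providers() -> dict[str, ToyOAuth | None]:
        # None marks a provider whose credentials are missing, as in the code above.
        return {"github": ToyOAuth("github"), "google": None}

    provider = get_oauth_providers().get("google")
    if not provider:
        response = ({"error": "Invalid provider"}, 400)  # what both endpoints return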
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index 72a6129efa..9a1d914869 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -9,28 +9,24 @@ from services.billing_service import BillingService
class Subscription(Resource):
-
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
-
parser = reqparse.RequestParser()
- parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
- parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
+ parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
+ parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
- return BillingService.get_subscription(args['plan'],
- args['interval'],
- current_user.email,
- current_user.current_tenant_id)
+ return BillingService.get_subscription(
+ args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
+ )
class Invoices(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -40,5 +36,5 @@ class Invoices(Resource):
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
-api.add_resource(Subscription, '/billing/subscription')
-api.add_resource(Invoices, '/billing/invoices')
+api.add_resource(Subscription, "/billing/subscription")
+api.add_resource(Invoices, "/billing/invoices")
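The Subscription endpoint whitelists both query arguments via reqparse choices. Outside a Flask request context, the same validation reduces to a membership check (hand-rolled equivalent, not reqparse itself):

    def parse_choice(value: str, choices: tuple[str, ...], name: str) -> str:
        # Mirrors reqparse's choices behavior: reject anything outside the whitelist.
        if value not in choices:
            raise ValueError(f"{name} must be one of {choices}")
        return value

    plan = parse_choice("professional", ("professional", "team"), "plan")
    interval = parse_choice("month", ("month", "year"), "interval")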
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 0ca0f0a856..0e1acab946 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
class DataSourceApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
- data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.disabled == False
- ).all()
+ data_source_integrates = (
+ db.session.query(DataSourceOauthBinding)
+ .filter(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.disabled == False,
+ )
+ .all()
+ )
- base_url = request.url_root.rstrip('/')
+ base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source"
providers = ["notion"]
@@ -44,26 +47,30 @@ class DataSourceApi(Resource):
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
if existing_integrates:
for existing_integrate in list(existing_integrates):
- integrate_data.append({
- 'id': existing_integrate.id,
- 'provider': provider,
- 'created_at': existing_integrate.created_at,
- 'is_bound': True,
- 'disabled': existing_integrate.disabled,
- 'source_info': existing_integrate.source_info,
- 'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
- })
+ integrate_data.append(
+ {
+ "id": existing_integrate.id,
+ "provider": provider,
+ "created_at": existing_integrate.created_at,
+ "is_bound": True,
+ "disabled": existing_integrate.disabled,
+ "source_info": existing_integrate.source_info,
+ "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+ }
+ )
else:
- integrate_data.append({
- 'id': None,
- 'provider': provider,
- 'created_at': None,
- 'source_info': None,
- 'is_bound': False,
- 'disabled': None,
- 'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
- })
- return {'data': integrate_data}, 200
+ integrate_data.append(
+ {
+ "id": None,
+ "provider": provider,
+ "created_at": None,
+ "source_info": None,
+ "is_bound": False,
+ "disabled": None,
+ "link": f"{base_url}{data_source_oauth_base_path}/{provider}",
+ }
+ )
+ return {"data": integrate_data}, 200
@setup_required
@login_required
@@ -71,92 +78,82 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
- data_source_binding = DataSourceOauthBinding.query.filter_by(
- id=binding_id
- ).first()
+ data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
if data_source_binding is None:
- raise NotFound('Data source binding not found.')
+ raise NotFound("Data source binding not found.")
# enable binding
- if action == 'enable':
+ if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
- raise ValueError('Data source is not disabled.')
+ raise ValueError("Data source is not disabled.")
# disable binding
- if action == 'disable':
+ if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding)
db.session.commit()
else:
- raise ValueError('Data source is disabled.')
- return {'result': 'success'}, 200
+ raise ValueError("Data source is disabled.")
+ return {"result": "success"}, 200
class DataSourceNotionListApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
def get(self):
- dataset_id = request.args.get('dataset_id', default=None, type=str)
+ dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = []
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
- if dataset.data_source_type != 'notion_import':
- raise ValueError('Dataset is not notion type.')
+ raise NotFound("Dataset not found.")
+ if dataset.data_source_type != "notion_import":
+ raise ValueError("Dataset is not notion type.")
documents = Document.query.filter_by(
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
- data_source_type='notion_import',
- enabled=True
+ data_source_type="notion_import",
+ enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
- exist_page_ids.append(data_source_info['notion_page_id'])
+ exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by(
- tenant_id=current_user.current_tenant_id,
- provider='notion',
- disabled=False
+ tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
).all()
if not data_source_bindings:
- return {
- 'notion_info': []
- }, 200
+ return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
- pages = source_info['pages']
+ pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
- if page['page_id'] in exist_page_ids:
- page['is_bound'] = True
+ if page["page_id"] in exist_page_ids:
+ page["is_bound"] = True
else:
- page['is_bound'] = False
+ page["is_bound"] = False
pre_import_info = {
- 'workspace_name': source_info['workspace_name'],
- 'workspace_icon': source_info['workspace_icon'],
- 'workspace_id': source_info['workspace_id'],
- 'pages': pages,
+ "workspace_name": source_info["workspace_name"],
+ "workspace_icon": source_info["workspace_icon"],
+ "workspace_id": source_info["workspace_id"],
+ "pages": pages,
}
pre_import_info_list.append(pre_import_info)
- return {
- 'notion_info': pre_import_info_list
- }, 200
+ return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource):
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == 'notion',
+ DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
- DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
+ DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
if not data_source_binding:
- raise NotFound('Data source binding not found.')
+ raise NotFound("Data source binding not found.")
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
- tenant_id=current_user.current_tenant_id
+ tenant_id=current_user.current_tenant_id,
)
text_docs = extractor.extract()
- return {
- 'content': "\n".join([doc.page_content for doc in text_docs])
- }, 200
+ return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
- parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
- parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
- parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+ parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
+ parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+ parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+ parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+ )
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
- notion_info_list = args['notion_info_list']
+ notion_info_list = args["notion_info_list"]
extract_settings = []
for notion_info in notion_info_list:
- workspace_id = notion_info['workspace_id']
- for page in notion_info['pages']:
+ workspace_id = notion_info["workspace_id"]
+ for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
- "notion_obj_id": page['page_id'],
- "notion_page_type": page['type'],
- "tenant_id": current_user.current_tenant_id
+ "notion_obj_id": page["page_id"],
+ "notion_page_type": page["type"],
+ "tenant_id": current_user.current_tenant_id,
},
- document_model=args['doc_form']
+ document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
- response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
- args['process_rule'], args['doc_form'],
- args['doc_language'])
+ response = indexing_runner.indexing_estimate(
+ current_user.current_tenant_id,
+ extract_settings,
+ args["process_rule"],
+ args["doc_form"],
+ args["doc_language"],
+ )
return response, 200
class DataSourceNotionDatasetSyncApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource):
class DataSourceNotionDocumentSyncApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource):
return 200
-api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
-api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
-api.add_resource(DataSourceNotionApi,
-                 '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
-                 '/datasets/notion-indexing-estimate')
-api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
-api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
+api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
+api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
+api.add_resource(
+    DataSourceNotionApi,
+    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
+    "/datasets/notion-indexing-estimate",
+)
+api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
+api.add_resource(
+    DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
+)
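The routes above rely on Flask's URL converters: a <uuid:...> segment is validated and converted to uuid.UUID before the handler runs, while <string:...> passes through as text. A standalone illustration (toy app, not Dify's):

    import uuid
    from flask import Flask

    app = Flask(__name__)

    @app.route("/data-source/integrates/<uuid:binding_id>/<string:action>")
    def patch_binding(binding_id: uuid.UUID, action: str):
        # binding_id arrives already parsed; malformed UUIDs 404 before this runs.
        return {"binding_id": str(binding_id), "action": action}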
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index b9a1c25154..d369730594 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
- raise ValueError('Name must be between 1 to 40 characters.')
+ raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
- raise ValueError('Description cannot exceed 400 characters.')
+ raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
def get(self):
- page = request.args.get('page', default=1, type=int)
- limit = request.args.get('limit', default=20, type=int)
- ids = request.args.getlist('ids')
- provider = request.args.get('provider', default="vendor")
- search = request.args.get('keyword', default=None, type=str)
- tag_ids = request.args.getlist('tag_ids')
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
+ ids = request.args.getlist("ids")
+ provider = request.args.get("provider", default="vendor")
+ search = request.args.get("keyword", default=None, type=str)
+ tag_ids = request.args.getlist("tag_ids")
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
- datasets, total = DatasetService.get_datasets(page, limit, provider,
- current_user.current_tenant_id, current_user, search, tag_ids)
+ datasets, total = DatasetService.get_datasets(
+ page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
+ )
# check embedding setting
provider_manager = ProviderManager()
- configurations = provider_manager.get_configurations(
- tenant_id=current_user.current_tenant_id
- )
+ configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
- embedding_models = configurations.get_models(
- model_type=ModelType.TEXT_EMBEDDING,
- only_active=True
- )
+ embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
@@ -77,28 +72,22 @@ class DatasetListApi(Resource):
data = marshal(datasets, dataset_detail_fields)
for item in data:
- if item['indexing_technique'] == 'high_quality':
+ if item["indexing_technique"] == "high_quality":
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
- item['embedding_available'] = True
+ item["embedding_available"] = True
else:
- item['embedding_available'] = False
+ item["embedding_available"] = False
else:
- item['embedding_available'] = True
+ item["embedding_available"] = True
- if item.get('permission') == 'partial_members':
- part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
- item.update({'partial_member_list': part_users_list})
+ if item.get("permission") == "partial_members":
+ part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
+ item.update({"partial_member_list": part_users_list})
else:
- item.update({'partial_member_list': []})
+ item.update({"partial_member_list": []})
- response = {
- 'data': data,
- 'has_more': len(datasets) == limit,
- 'limit': limit,
- 'total': total,
- 'page': page
- }
+ response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
@setup_required
@@ -106,13 +95,21 @@ class DatasetListApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('name', nullable=False, required=True,
- help='type is required. Name must be between 1 to 40 characters.',
- type=_validate_name)
- parser.add_argument('indexing_technique', type=str, location='json',
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- help='Invalid indexing technique.')
+ parser.add_argument(
+ "name",
+ nullable=False,
+ required=True,
+ help="type is required. Name must be between 1 to 40 characters.",
+ type=_validate_name,
+ )
+ parser.add_argument(
+ "indexing_technique",
+ type=str,
+ location="json",
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ nullable=True,
+ help="Invalid indexing technique.",
+ )
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -122,9 +119,9 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
- name=args['name'],
- indexing_technique=args['indexing_technique'],
- account=current_user
+ name=args["name"],
+ indexing_technique=args["indexing_technique"],
+ account=current_user,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@@ -142,42 +139,36 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
try:
- DatasetService.check_dataset_permission(
- dataset, current_user)
+ DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
- if data.get('permission') == 'partial_members':
+ if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
- data.update({'partial_member_list': part_users_list})
+ data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
- configurations = provider_manager.get_configurations(
- tenant_id=current_user.current_tenant_id
- )
+ configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
- embedding_models = configurations.get_models(
- model_type=ModelType.TEXT_EMBEDDING,
- only_active=True
- )
+ embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
- if data['indexing_technique'] == 'high_quality':
+ if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
- data['embedding_available'] = True
+ data["embedding_available"] = True
else:
- data['embedding_available'] = False
+ data["embedding_available"] = False
else:
- data['embedding_available'] = True
+ data["embedding_available"] = True
- if data.get('permission') == 'partial_members':
+ if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
- data.update({'partial_member_list': part_users_list})
+ data.update({"partial_member_list": part_users_list})
return data, 200
@@ -191,42 +182,49 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
- parser.add_argument('name', nullable=False,
- help='type is required. Name must be between 1 to 40 characters.',
- type=_validate_name)
- parser.add_argument('description',
- location='json', store_missing=False,
- type=_validate_description_length)
- parser.add_argument('indexing_technique', type=str, location='json',
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True,
- help='Invalid indexing technique.')
- parser.add_argument('permission', type=str, location='json', choices=(
- DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.'
- )
- parser.add_argument('embedding_model', type=str,
- location='json', help='Invalid embedding model.')
- parser.add_argument('embedding_model_provider', type=str,
- location='json', help='Invalid embedding model provider.')
- parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
- parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
+ parser.add_argument(
+ "name",
+ nullable=False,
+ help="type is required. Name must be between 1 to 40 characters.",
+ type=_validate_name,
+ )
+ parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
+ parser.add_argument(
+ "indexing_technique",
+ type=str,
+ location="json",
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ nullable=True,
+ help="Invalid indexing technique.",
+ )
+ parser.add_argument(
+ "permission",
+ type=str,
+ location="json",
+ choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
+ help="Invalid permission.",
+ )
+ parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
+ parser.add_argument(
+ "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
+ )
+ parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
+ parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
- if data.get('indexing_technique') == 'high_quality':
- DatasetService.check_embedding_model_setting(dataset.tenant_id,
- data.get('embedding_model_provider'),
- data.get('embedding_model')
- )
+ if data.get("indexing_technique") == "high_quality":
+ DatasetService.check_embedding_model_setting(
+ dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+ )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
- current_user, dataset, data.get('permission'), data.get('partial_member_list')
+ current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
- dataset = DatasetService.update_dataset(
- dataset_id_str, args, current_user)
+ dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -234,16 +232,19 @@ class DatasetApi(Resource):
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
- if data.get('partial_member_list') and data.get('permission') == 'partial_members':
+ if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
- tenant_id, dataset_id_str, data.get('partial_member_list')
+ tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
- elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM:
+ elif (
+ data.get("permission") == DatasetPermissionEnum.ONLY_ME
+ or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
+ ):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
- result_data.update({'partial_member_list': partial_member_list})
+ result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
@@ -260,12 +261,13 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError()
+
class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
- return {'is_using': dataset_is_using}, 200
+ return {"is_using": dataset_is_using}, 200
+
class DatasetQueryApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -292,51 +294,53 @@ class DatasetQueryApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
- page = request.args.get('page', default=1, type=int)
- limit = request.args.get('limit', default=20, type=int)
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
- dataset_queries, total = DatasetService.get_dataset_queries(
- dataset_id=dataset.id,
- page=page,
- per_page=limit
- )
+ dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
- 'data': marshal(dataset_queries, dataset_query_detail_fields),
- 'has_more': len(dataset_queries) == limit,
- 'limit': limit,
- 'total': total,
- 'page': page
+ "data": marshal(dataset_queries, dataset_query_detail_fields),
+ "has_more": len(dataset_queries) == limit,
+ "limit": limit,
+ "total": total,
+ "page": page,
}
return response, 200
class DatasetIndexingEstimateApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
- parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
- parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
- parser.add_argument('indexing_technique', type=str, required=True,
- choices=Dataset.INDEXING_TECHNIQUE_LIST,
- nullable=True, location='json')
- parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
- parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
- parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
- location='json')
+ parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
+ parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+ parser.add_argument(
+ "indexing_technique",
+ type=str,
+ required=True,
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ nullable=True,
+ location="json",
+ )
+ parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+ parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
+ parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+ )
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
- if args['info_list']['data_source_type'] == 'upload_file':
- file_ids = args['info_list']['file_info_list']['file_ids']
- file_details = db.session.query(UploadFile).filter(
- UploadFile.tenant_id == current_user.current_tenant_id,
- UploadFile.id.in_(file_ids)
- ).all()
+ if args["info_list"]["data_source_type"] == "upload_file":
+ file_ids = args["info_list"]["file_info_list"]["file_ids"]
+ file_details = (
+ db.session.query(UploadFile)
+ .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
+ .all()
+ )
if file_details is None:
raise NotFound("File not found.")
@@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
- datasource_type="upload_file",
- upload_file=file_detail,
- document_model=args['doc_form']
+ datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
)
extract_settings.append(extract_setting)
- elif args['info_list']['data_source_type'] == 'notion_import':
- notion_info_list = args['info_list']['notion_info_list']
+ elif args["info_list"]["data_source_type"] == "notion_import":
+ notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
- workspace_id = notion_info['workspace_id']
- for page in notion_info['pages']:
+ workspace_id = notion_info["workspace_id"]
+ for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
- "notion_obj_id": page['page_id'],
- "notion_page_type": page['type'],
- "tenant_id": current_user.current_tenant_id
+ "notion_obj_id": page["page_id"],
+ "notion_page_type": page["type"],
+ "tenant_id": current_user.current_tenant_id,
},
- document_model=args['doc_form']
+ document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
- elif args['info_list']['data_source_type'] == 'website_crawl':
- website_info_list = args['info_list']['website_info_list']
- for url in website_info_list['urls']:
+ elif args["info_list"]["data_source_type"] == "website_crawl":
+ website_info_list = args["info_list"]["website_info_list"]
+ for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
- "provider": website_info_list['provider'],
- "job_id": website_info_list['job_id'],
+ "provider": website_info_list["provider"],
+ "job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
- "mode": 'crawl',
- "only_main_content": website_info_list['only_main_content']
+ "mode": "crawl",
+ "only_main_content": website_info_list["only_main_content"],
},
- document_model=args['doc_form']
+ document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
else:
- raise ValueError('Data source type not support')
+ raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
- response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
- args['process_rule'], args['doc_form'],
- args['doc_language'], args['dataset_id'],
- args['indexing_technique'])
+ response = indexing_runner.indexing_estimate(
+ current_user.current_tenant_id,
+ extract_settings,
+ args["process_rule"],
+ args["doc_form"],
+ args["doc_language"],
+ args["dataset_id"],
+ args["indexing_technique"],
+ )
except LLMBadRequestError:
raise ProviderNotInitializeError(
- "No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider.")
+ "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+ )
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
@@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
- return {
- 'data': related_apps,
- 'total': len(related_apps)
- }, 200
+ return {"data": related_apps, "total": len(related_apps)}, 200
class DatasetIndexingStatusApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
- documents = db.session.query(Document).filter(
- Document.dataset_id == dataset_id,
- Document.tenant_id == current_user.current_tenant_id
- ).all()
+ documents = (
+ db.session.query(Document)
+ .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
+ .all()
+ )
documents_status = []
for document in documents:
- completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
- DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
- total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
+ completed_segments = DocumentSegment.query.filter(
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.document_id == str(document.id),
+ DocumentSegment.status != "re_segment",
+ ).count()
+ total_segments = DocumentSegment.query.filter(
+ DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+ ).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
- data = {
- 'data': documents_status
- }
+ data = {"data": documents_status}
return data
class DatasetApiKeyApi(Resource):
max_keys = 10
- token_prefix = 'dataset-'
- resource_type = 'dataset'
+ token_prefix = "dataset-"
+ resource_type = "dataset"
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
- keys = db.session.query(ApiToken). \
- filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
- all()
+ keys = (
+ db.session.query(ApiToken)
+ .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+ .all()
+ )
return {"items": keys}
@setup_required
@@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
- current_key_count = db.session.query(ApiToken). \
- filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
- count()
+ current_key_count = (
+ db.session.query(ApiToken)
+ .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+ .count()
+ )
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
- code='max_keys_exceeded'
+ code="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource):
- resource_type = 'dataset'
+ resource_type = "dataset"
@setup_required
@login_required
@@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
- key = db.session.query(ApiToken). \
- filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
- ApiToken.id == api_key_id). \
- first()
+ key = (
+ db.session.query(ApiToken)
+ .filter(
+ ApiToken.tenant_id == current_user.current_tenant_id,
+ ApiToken.type == self.resource_type,
+ ApiToken.id == api_key_id,
+ )
+ .first()
+ )
if key is None:
- flask_restful.abort(404, message='API key not found')
+ flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
class DatasetApiBaseUrlApi(Resource):
@@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource):
@account_initialization_required
def get(self):
return {
- 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
- else request.host_url.rstrip('/')) + '/v1'
+ "api_base_url": (
+ dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
+ )
+ + "/v1"
}
@@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource):
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
- case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
+ case (
+ VectorType.MILVUS
+ | VectorType.RELYT
+ | VectorType.PGVECTOR
+ | VectorType.TIDB_VECTOR
+ | VectorType.CHROMA
+ | VectorType.TENCENT
+ ):
+ return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+ case (
+ VectorType.QDRANT
+ | VectorType.WEAVIATE
+ | VectorType.OPENSEARCH
+ | VectorType.ANALYTICDB
+ | VectorType.MYSCALE
+ | VectorType.ORACLE
+ | VectorType.ELASTICSEARCH
+ ):
return {
- 'retrieval_method': [
- RetrievalMethod.SEMANTIC_SEARCH.value
- ]
- }
- case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
- return {
- 'retrieval_method': [
+ "retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
- case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS:
+ case (
+ VectorType.MILVUS
+ | VectorType.RELYT
+ | VectorType.TIDB_VECTOR
+ | VectorType.CHROMA
+ | VectorType.TENCENT
+ | VectorType.PGVECTO_RS
+ ):
+ return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
+ case (
+ VectorType.QDRANT
+ | VectorType.WEAVIATE
+ | VectorType.OPENSEARCH
+ | VectorType.ANALYTICDB
+ | VectorType.MYSCALE
+ | VectorType.ORACLE
+ | VectorType.ELASTICSEARCH
+ | VectorType.PGVECTOR
+ ):
return {
- 'retrieval_method': [
- RetrievalMethod.SEMANTIC_SEARCH.value
- ]
- }
- case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR:
- return {
- 'retrieval_method': [
+ "retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
@@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.")
-
class DatasetErrorDocs(Resource):
@setup_required
@login_required
@@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
- return {
- 'data': [marshal(item, document_status_fields) for item in results],
- 'total': len(results)
- }, 200
+ return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
class DatasetPermissionUserListApi(Resource):
@@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
- 'data': partial_members_list,
+ "data": partial_members_list,
}, 200
-api.add_resource(DatasetListApi, '/datasets')
-api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
-api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
-api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
-api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
-api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
-api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
-api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
-api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
-api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
-api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
-api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
-api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
-api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
+api.add_resource(DatasetListApi, "/datasets")
+api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
+api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
+api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries")
+api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs")
+api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate")
+api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps")
+api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status")
+api.add_resource(DatasetApiKeyApi, "/datasets/api-keys")
+api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>")
+api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
+api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
+api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
+api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
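Many hunks in this file reflow long Query chains into the parenthesized, one-clause-per-line style. A self-contained example of the same shape (toy model on in-memory SQLite; assumes SQLAlchemy 1.4+):

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class ApiToken(Base):
        __tablename__ = "api_tokens"
        id = Column(String, primary_key=True)
        type = Column(String)
        tenant_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Each clause sits on its own line inside the parentheses, as above.
        keys = (
            session.query(ApiToken)
            .filter(ApiToken.type == "dataset", ApiToken.tenant_id == "t1")
            .all()
        )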
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 976b97660a..7d0b9f0460 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -57,7 +57,7 @@ class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -67,17 +67,17 @@ class DocumentResource(Resource):
document = DocumentService.get_document(dataset_id, document_id)
if not document:
- raise NotFound('Document not found.')
+ raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
- raise Forbidden('No permission.')
+ raise Forbidden("No permission.")
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -87,7 +87,7 @@ class DocumentResource(Resource):
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
- raise NotFound('Documents not found.')
+ raise NotFound("Documents not found.")
return documents
@@ -99,11 +99,11 @@ class GetProcessRuleApi(Resource):
def get(self):
req_data = request.args
- document_id = req_data.get('document_id')
+ document_id = req_data.get("document_id")
# get default rules
- mode = DocumentService.DEFAULT_RULES['mode']
- rules = DocumentService.DEFAULT_RULES['rules']
+ mode = DocumentService.DEFAULT_RULES["mode"]
+ rules = DocumentService.DEFAULT_RULES["rules"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
@@ -111,7 +111,7 @@ class GetProcessRuleApi(Resource):
dataset = DatasetService.get_dataset(document.dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
@@ -119,19 +119,18 @@ class GetProcessRuleApi(Resource):
raise Forbidden(str(e))
# get the latest process rule
- dataset_process_rule = db.session.query(DatasetProcessRule). \
- filter(DatasetProcessRule.dataset_id == document.dataset_id). \
- order_by(DatasetProcessRule.created_at.desc()). \
- limit(1). \
- one_or_none()
+ dataset_process_rule = (
+ db.session.query(DatasetProcessRule)
+ .filter(DatasetProcessRule.dataset_id == document.dataset_id)
+ .order_by(DatasetProcessRule.created_at.desc())
+ .limit(1)
+ .one_or_none()
+ )
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
- return {
- 'mode': mode,
- 'rules': rules
- }
+ return {"mode": mode, "rules": rules}
class DatasetDocumentListApi(Resource):
@@ -140,49 +139,48 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
- page = request.args.get('page', default=1, type=int)
- limit = request.args.get('limit', default=20, type=int)
- search = request.args.get('keyword', default=None, type=str)
- sort = request.args.get('sort', default='-created_at', type=str)
+ page = request.args.get("page", default=1, type=int)
+ limit = request.args.get("limit", default=20, type=int)
+ search = request.args.get("keyword", default=None, type=str)
+ sort = request.args.get("sort", default="-created_at", type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
- fetch = string_to_bool(request.args.get('fetch', default='false'))
+ fetch = string_to_bool(request.args.get("fetch", default="false"))
except (ArgumentTypeError, ValueError, Exception) as e:
fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
- query = Document.query.filter_by(
- dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+ query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
- search = f'%{search}%'
+ search = f"%{search}%"
query = query.filter(Document.name.like(search))
- if sort.startswith('-'):
+ if sort.startswith("-"):
sort_logic = desc
sort = sort[1:]
else:
sort_logic = asc
- if sort == 'hit_count':
- sub_query = db.select(DocumentSegment.document_id,
- db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \
- .group_by(DocumentSegment.document_id) \
+ if sort == "hit_count":
+ sub_query = (
+ db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
+ .group_by(DocumentSegment.document_id)
.subquery()
+ )
- query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \
- .order_by(
- sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
- sort_logic(Document.position),
- )
- elif sort == 'created_at':
+ query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
+ sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
+ sort_logic(Document.position),
+ )
+ elif sort == "created_at":
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
@@ -193,48 +191,47 @@ class DatasetDocumentListApi(Resource):
desc(Document.position),
)
- paginated_documents = query.paginate(
- page=page, per_page=limit, max_per_page=100, error_out=False)
+ paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
- completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
- DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
- total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
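+ # indexing progress: completed vs. total segments, excluding "re_segment" rows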
+ completed_segments = DocumentSegment.query.filter(
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.document_id == str(document.id),
+ DocumentSegment.status != "re_segment",
+ ).count()
+ total_segments = DocumentSegment.query.filter(
+ DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+ ).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
- 'data': data,
- 'has_more': len(documents) == limit,
- 'limit': limit,
- 'total': paginated_documents.total,
- 'page': page
+ "data": data,
+ "has_more": len(documents) == limit,
+ "limit": limit,
+ "total": paginated_documents.total,
+ "page": page,
}
return response
- documents_and_batch_fields = {
- 'documents': fields.List(fields.Nested(document_fields)),
- 'batch': fields.String
- }
+ documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
@setup_required
@login_required
@account_initialization_required
@marshal_with(documents_and_batch_fields)
- @cloud_edition_billing_resource_check('vector_space')
+ @cloud_edition_billing_resource_check("vector_space")
def post(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
@@ -246,21 +243,22 @@ class DatasetDocumentListApi(Resource):
raise Forbidden(str(e))
parser = reqparse.RequestParser()
- parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
- location='json')
- parser.add_argument('data_source', type=dict, required=False, location='json')
- parser.add_argument('process_rule', type=dict, required=False, location='json')
- parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
- parser.add_argument('original_document_id', type=str, required=False, location='json')
- parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
- parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
- location='json')
- parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
- location='json')
+ parser.add_argument(
+ "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
+ )
+ parser.add_argument("data_source", type=dict, required=False, location="json")
+ parser.add_argument("process_rule", type=dict, required=False, location="json")
+ parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
+ parser.add_argument("original_document_id", type=str, required=False, location="json")
+ parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+ parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+ )
+ parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args()
- if not dataset.indexing_technique and not args['indexing_technique']:
- raise ValueError('indexing_technique is required.')
+ if not dataset.indexing_technique and not args["indexing_technique"]:
+ raise ValueError("indexing_technique is required.")
# validate args
DocumentService.document_create_args_validate(args)
@@ -274,51 +272,53 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- return {
- 'documents': documents,
- 'batch': batch
- }
+ return {"documents": documents, "batch": batch}
class DatasetInitApi(Resource):
-
@setup_required
@login_required
@account_initialization_required
@marshal_with(dataset_and_document_fields)
- @cloud_edition_billing_resource_check('vector_space')
+ @cloud_edition_billing_resource_check("vector_space")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
- parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True,
- nullable=False, location='json')
- parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
- parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
- parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
- parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
- location='json')
- parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
- location='json')
+ parser.add_argument(
+ "indexing_technique",
+ type=str,
+ choices=Dataset.INDEXING_TECHNIQUE_LIST,
+ required=True,
+ nullable=False,
+ location="json",
+ )
+ parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
+ parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
+ parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
+ parser.add_argument(
+ "doc_language", type=str, default="English", required=False, nullable=False, location="json"
+ )
+ parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
- if args['indexing_technique'] == 'high_quality':
+ if args["indexing_technique"] == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(
- tenant_id=current_user.current_tenant_id,
- model_type=ModelType.TEXT_EMBEDDING
+ tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider.")
+ "in the Settings -> Model Provider."
+ )
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -327,9 +327,7 @@ class DatasetInitApi(Resource):
try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
- tenant_id=current_user.current_tenant_id,
- document_data=args,
- account=current_user
+ tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -338,17 +336,12 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- response = {
- 'dataset': dataset,
- 'documents': documents,
- 'batch': batch
- }
+ response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
class DocumentIndexingEstimateApi(DocumentResource):
-
@setup_required
@login_required
@account_initialization_required
@@ -357,50 +350,49 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
- if document.indexing_status in ['completed', 'error']:
+ if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
- response = {
- "tokens": 0,
- "total_price": 0,
- "currency": "USD",
- "total_segments": 0,
- "preview": []
- }
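+ # default empty estimate; replaced below only when the upload file can be resolved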
+ response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
- if document.data_source_type == 'upload_file':
+ if document.data_source_type == "upload_file":
data_source_info = document.data_source_info_dict
- if data_source_info and 'upload_file_id' in data_source_info:
- file_id = data_source_info['upload_file_id']
+ if data_source_info and "upload_file_id" in data_source_info:
+ file_id = data_source_info["upload_file_id"]
- file = db.session.query(UploadFile).filter(
- UploadFile.tenant_id == document.tenant_id,
- UploadFile.id == file_id
- ).first()
+ file = (
+ db.session.query(UploadFile)
+ .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
+ .first()
+ )
# raise error if file not found
if not file:
- raise NotFound('File not found.')
+ raise NotFound("File not found.")
extract_setting = ExtractSetting(
- datasource_type="upload_file",
- upload_file=file,
- document_model=document.doc_form
+ datasource_type="upload_file", upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
try:
- response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
- data_process_rule_dict, document.doc_form,
- 'English', dataset_id)
+ response = indexing_runner.indexing_estimate(
+ current_user.current_tenant_id,
+ [extract_setting],
+ data_process_rule_dict,
+ document.doc_form,
+ "English",
+ dataset_id,
+ )
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider.")
+ "in the Settings -> Model Provider."
+ )
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
@@ -410,7 +402,6 @@ class DocumentIndexingEstimateApi(DocumentResource):
class DocumentBatchIndexingEstimateApi(DocumentResource):
-
@setup_required
@login_required
@account_initialization_required
@@ -418,13 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
- response = {
- "tokens": 0,
- "total_price": 0,
- "currency": "USD",
- "total_segments": 0,
- "preview": []
- }
+ response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
if not documents:
return response
data_process_rule = documents[0].dataset_process_rule
@@ -432,82 +417,83 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = []
extract_settings = []
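+ # build one ExtractSetting per document according to its data source type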
for document in documents:
- if document.indexing_status in ['completed', 'error']:
+ if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
- if data_source_info and 'upload_file_id' in data_source_info:
- file_id = data_source_info['upload_file_id']
+ if data_source_info and "upload_file_id" in data_source_info:
+ file_id = data_source_info["upload_file_id"]
info_list.append(file_id)
# format document notion info
- elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info:
+ elif (
+ data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info
+ ):
pages = []
- page = {
- 'page_id': data_source_info['notion_page_id'],
- 'type': data_source_info['type']
- }
+ page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]}
pages.append(page)
- notion_info = {
- 'workspace_id': data_source_info['notion_workspace_id'],
- 'pages': pages
- }
+ notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages}
info_list.append(notion_info)
- if document.data_source_type == 'upload_file':
- file_id = data_source_info['upload_file_id']
- file_detail = db.session.query(UploadFile).filter(
- UploadFile.tenant_id == current_user.current_tenant_id,
- UploadFile.id == file_id
- ).first()
+ if document.data_source_type == "upload_file":
+ file_id = data_source_info["upload_file_id"]
+ file_detail = (
+ db.session.query(UploadFile)
+ .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
+ .first()
+ )
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
- datasource_type="upload_file",
- upload_file=file_detail,
- document_model=document.doc_form
+ datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
- elif document.data_source_type == 'notion_import':
+ elif document.data_source_type == "notion_import":
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
- "notion_workspace_id": data_source_info['notion_workspace_id'],
- "notion_obj_id": data_source_info['notion_page_id'],
- "notion_page_type": data_source_info['type'],
- "tenant_id": current_user.current_tenant_id
+ "notion_workspace_id": data_source_info["notion_workspace_id"],
+ "notion_obj_id": data_source_info["notion_page_id"],
+ "notion_page_type": data_source_info["type"],
+ "tenant_id": current_user.current_tenant_id,
},
- document_model=document.doc_form
+ document_model=document.doc_form,
)
extract_settings.append(extract_setting)
- elif document.data_source_type == 'website_crawl':
+ elif document.data_source_type == "website_crawl":
extract_setting = ExtractSetting(
datasource_type="website_crawl",
website_info={
- "provider": data_source_info['provider'],
- "job_id": data_source_info['job_id'],
- "url": data_source_info['url'],
+ "provider": data_source_info["provider"],
+ "job_id": data_source_info["job_id"],
+ "url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
- "mode": data_source_info['mode'],
- "only_main_content": data_source_info['only_main_content']
+ "mode": data_source_info["mode"],
+ "only_main_content": data_source_info["only_main_content"],
},
- document_model=document.doc_form
+ document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
- raise ValueError('Data source type not support')
+ raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
- response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
- data_process_rule_dict, document.doc_form,
- 'English', dataset_id)
+ response = indexing_runner.indexing_estimate(
+ current_user.current_tenant_id,
+ extract_settings,
+ data_process_rule_dict,
+ document.doc_form,
+ "English",
+ dataset_id,
+ )
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider.")
+ "in the Settings -> Model Provider."
+ )
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except Exception as e:
@@ -516,7 +502,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
class DocumentBatchIndexingStatusApi(DocumentResource):
-
@setup_required
@login_required
@account_initialization_required
@@ -526,24 +511,24 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
- completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
- DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
- total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
- DocumentSegment.status != 're_segment').count()
+ completed_segments = DocumentSegment.query.filter(
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.document_id == str(document.id),
+ DocumentSegment.status != "re_segment",
+ ).count()
+ total_segments = DocumentSegment.query.filter(
+ DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+ ).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
- document.indexing_status = 'paused'
+ document.indexing_status = "paused"
documents_status.append(marshal(document, document_status_fields))
- data = {
- 'data': documents_status
- }
+ data = {"data": documents_status}
return data
class DocumentIndexingStatusApi(DocumentResource):
-
@setup_required
@login_required
@account_initialization_required
@@ -552,25 +537,24 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
- completed_segments = DocumentSegment.query \
- .filter(DocumentSegment.completed_at.isnot(None),
- DocumentSegment.document_id == str(document_id),
- DocumentSegment.status != 're_segment') \
- .count()
- total_segments = DocumentSegment.query \
- .filter(DocumentSegment.document_id == str(document_id),
- DocumentSegment.status != 're_segment') \
- .count()
+ completed_segments = DocumentSegment.query.filter(
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.document_id == str(document_id),
+ DocumentSegment.status != "re_segment",
+ ).count()
+ total_segments = DocumentSegment.query.filter(
+ DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
+ ).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
- document.indexing_status = 'paused'
+ document.indexing_status = "paused"
return marshal(document, document_status_fields)
class DocumentDetailApi(DocumentResource):
- METADATA_CHOICES = {'all', 'only', 'without'}
+ METADATA_CHOICES = {"all", "only", "without"}
@setup_required
@login_required
@@ -580,77 +564,73 @@ class DocumentDetailApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
- metadata = request.args.get('metadata', 'all')
+ metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
- raise InvalidMetadataError(f'Invalid metadata value: {metadata}')
+ raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
- if metadata == 'only':
- response = {
- 'id': document.id,
- 'doc_type': document.doc_type,
- 'doc_metadata': document.doc_metadata
- }
- elif metadata == 'without':
+ if metadata == "only":
+ response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
+ elif metadata == "without":
process_rules = DatasetService.get_process_rules(dataset_id)
data_source_info = document.data_source_detail_dict
response = {
- 'id': document.id,
- 'position': document.position,
- 'data_source_type': document.data_source_type,
- 'data_source_info': data_source_info,
- 'dataset_process_rule_id': document.dataset_process_rule_id,
- 'dataset_process_rule': process_rules,
- 'name': document.name,
- 'created_from': document.created_from,
- 'created_by': document.created_by,
- 'created_at': document.created_at.timestamp(),
- 'tokens': document.tokens,
- 'indexing_status': document.indexing_status,
- 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
- 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
- 'indexing_latency': document.indexing_latency,
- 'error': document.error,
- 'enabled': document.enabled,
- 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
- 'disabled_by': document.disabled_by,
- 'archived': document.archived,
- 'segment_count': document.segment_count,
- 'average_segment_length': document.average_segment_length,
- 'hit_count': document.hit_count,
- 'display_status': document.display_status,
- 'doc_form': document.doc_form
+ "id": document.id,
+ "position": document.position,
+ "data_source_type": document.data_source_type,
+ "data_source_info": data_source_info,
+ "dataset_process_rule_id": document.dataset_process_rule_id,
+ "dataset_process_rule": process_rules,
+ "name": document.name,
+ "created_from": document.created_from,
+ "created_by": document.created_by,
+ "created_at": document.created_at.timestamp(),
+ "tokens": document.tokens,
+ "indexing_status": document.indexing_status,
+ "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
+ "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
+ "indexing_latency": document.indexing_latency,
+ "error": document.error,
+ "enabled": document.enabled,
+ "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
+ "disabled_by": document.disabled_by,
+ "archived": document.archived,
+ "segment_count": document.segment_count,
+ "average_segment_length": document.average_segment_length,
+ "hit_count": document.hit_count,
+ "display_status": document.display_status,
+ "doc_form": document.doc_form,
}
else:
process_rules = DatasetService.get_process_rules(dataset_id)
data_source_info = document.data_source_detail_dict
response = {
- 'id': document.id,
- 'position': document.position,
- 'data_source_type': document.data_source_type,
- 'data_source_info': data_source_info,
- 'dataset_process_rule_id': document.dataset_process_rule_id,
- 'dataset_process_rule': process_rules,
- 'name': document.name,
- 'created_from': document.created_from,
- 'created_by': document.created_by,
- 'created_at': document.created_at.timestamp(),
- 'tokens': document.tokens,
- 'indexing_status': document.indexing_status,
- 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None,
- 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None,
- 'indexing_latency': document.indexing_latency,
- 'error': document.error,
- 'enabled': document.enabled,
- 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None,
- 'disabled_by': document.disabled_by,
- 'archived': document.archived,
- 'doc_type': document.doc_type,
- 'doc_metadata': document.doc_metadata,
- 'segment_count': document.segment_count,
- 'average_segment_length': document.average_segment_length,
- 'hit_count': document.hit_count,
- 'display_status': document.display_status,
- 'doc_form': document.doc_form
+ "id": document.id,
+ "position": document.position,
+ "data_source_type": document.data_source_type,
+ "data_source_info": data_source_info,
+ "dataset_process_rule_id": document.dataset_process_rule_id,
+ "dataset_process_rule": process_rules,
+ "name": document.name,
+ "created_from": document.created_from,
+ "created_by": document.created_by,
+ "created_at": document.created_at.timestamp(),
+ "tokens": document.tokens,
+ "indexing_status": document.indexing_status,
+ "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
+ "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
+ "indexing_latency": document.indexing_latency,
+ "error": document.error,
+ "enabled": document.enabled,
+ "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
+ "disabled_by": document.disabled_by,
+ "archived": document.archived,
+ "doc_type": document.doc_type,
+ "doc_metadata": document.doc_metadata,
+ "segment_count": document.segment_count,
+ "average_segment_length": document.average_segment_length,
+ "hit_count": document.hit_count,
+ "display_status": document.display_status,
+ "doc_form": document.doc_form,
}
return response, 200
@@ -671,7 +651,7 @@ class DocumentProcessingApi(DocumentResource):
if action == "pause":
if document.indexing_status != "indexing":
- raise InvalidActionError('Document not in indexing state.')
+ raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -680,7 +660,7 @@ class DocumentProcessingApi(DocumentResource):
elif action == "resume":
if document.indexing_status not in ["paused", "error"]:
- raise InvalidActionError('Document not in paused or error state.')
+ raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
@@ -689,7 +669,7 @@ class DocumentProcessingApi(DocumentResource):
else:
raise InvalidActionError()
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
class DocumentDeleteApi(DocumentResource):
@@ -710,9 +690,9 @@ class DocumentDeleteApi(DocumentResource):
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
- raise DocumentIndexingError('Cannot delete document during indexing.')
+ raise DocumentIndexingError("Cannot delete document during indexing.")
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
class DocumentMetadataApi(DocumentResource):
@@ -726,26 +706,26 @@ class DocumentMetadataApi(DocumentResource):
req_data = request.get_json()
- doc_type = req_data.get('doc_type')
- doc_metadata = req_data.get('doc_metadata')
+ doc_type = req_data.get("doc_type")
+ doc_metadata = req_data.get("doc_metadata")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if doc_type is None or doc_metadata is None:
- raise ValueError('Both doc_type and doc_metadata must be provided.')
+ raise ValueError("Both doc_type and doc_metadata must be provided.")
if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
- raise ValueError('Invalid doc_type.')
+ raise ValueError("Invalid doc_type.")
if not isinstance(doc_metadata, dict):
- raise ValueError('doc_metadata must be a dictionary.')
+ raise ValueError("doc_metadata must be a dictionary.")
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {}
- if doc_type == 'others':
+ if doc_type == "others":
document.doc_metadata = doc_metadata
else:
for key, value_type in metadata_schema.items():
@@ -757,14 +737,14 @@ class DocumentMetadataApi(DocumentResource):
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
- return {'result': 'success', 'message': 'Document metadata updated.'}, 200
+ return {"result": "success", "message": "Document metadata updated."}, 200
class DocumentStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
- @cloud_edition_billing_resource_check('vector_space')
+ @cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
@@ -784,14 +764,14 @@ class DocumentStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
- indexing_cache_key = 'document_{}_indexing'.format(document.id)
+ indexing_cache_key = "document_{}_indexing".format(document.id)
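+ # key exists while the document is being indexed; treated as a lock on status changes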
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
if action == "enable":
if document.enabled:
- raise InvalidActionError('Document already enabled.')
+ raise InvalidActionError("Document already enabled.")
document.enabled = True
document.disabled_at = None
@@ -804,13 +784,13 @@ class DocumentStatusApi(DocumentResource):
add_document_to_index_task.delay(document_id)
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
elif action == "disable":
- if not document.completed_at or document.indexing_status != 'completed':
- raise InvalidActionError('Document is not completed.')
+ if not document.completed_at or document.indexing_status != "completed":
+ raise InvalidActionError("Document is not completed.")
if not document.enabled:
- raise InvalidActionError('Document already disabled.')
+ raise InvalidActionError("Document already disabled.")
document.enabled = False
document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -823,11 +803,11 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
elif action == "archive":
if document.archived:
- raise InvalidActionError('Document already archived.')
+ raise InvalidActionError("Document already archived.")
document.archived = True
document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -841,10 +821,10 @@ class DocumentStatusApi(DocumentResource):
remove_document_from_index_task.delay(document_id)
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
elif action == "un_archive":
if not document.archived:
- raise InvalidActionError('Document is not archived.')
+ raise InvalidActionError("Document is not archived.")
document.archived = False
document.archived_at = None
@@ -857,13 +837,12 @@ class DocumentStatusApi(DocumentResource):
add_document_to_index_task.delay(document_id)
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
else:
raise InvalidActionError()
class DocumentPauseApi(DocumentResource):
-
@setup_required
@login_required
@account_initialization_required
@@ -874,7 +853,7 @@ class DocumentPauseApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
@@ -890,9 +869,9 @@ class DocumentPauseApi(DocumentResource):
# pause document
DocumentService.pause_document(document)
except services.errors.document.DocumentIndexingError:
- raise DocumentIndexingError('Cannot pause completed document.')
+ raise DocumentIndexingError("Cannot pause completed document.")
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
class DocumentRecoverApi(DocumentResource):
@@ -905,7 +884,7 @@ class DocumentRecoverApi(DocumentResource):
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
@@ -919,9 +898,9 @@ class DocumentRecoverApi(DocumentResource):
# recover document
DocumentService.recover_document(document)
except services.errors.document.DocumentIndexingError:
- raise DocumentIndexingError('Document is not in paused status.')
+ raise DocumentIndexingError("Document is not in paused status.")
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
class DocumentRetryApi(DocumentResource):
@@ -932,15 +911,14 @@ class DocumentRetryApi(DocumentResource):
"""retry document."""
parser = reqparse.RequestParser()
- parser.add_argument('document_ids', type=list, required=True, nullable=False,
- location='json')
+ parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
- raise NotFound('Dataset not found.')
- for document_id in args['document_ids']:
+ raise NotFound("Dataset not found.")
+ for document_id in args["document_ids"]:
try:
document_id = str(document_id)
@@ -955,7 +933,7 @@ class DocumentRetryApi(DocumentResource):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
- if document.indexing_status == 'completed':
+ if document.indexing_status == "completed":
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception as e:
@@ -964,7 +942,7 @@ class DocumentRetryApi(DocumentResource):
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
- return {'result': 'success'}, 204
+ return {"result": "success"}, 204
class DocumentRenameApi(DocumentResource):
@@ -979,13 +957,13 @@ class DocumentRenameApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
- parser.add_argument('name', type=str, required=True, nullable=False, location='json')
+ parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
- document = DocumentService.rename_document(dataset_id, document_id, args['name'])
+ document = DocumentService.rename_document(dataset_id, document_id, args["name"])
except services.errors.document.DocumentIndexingError:
- raise DocumentIndexingError('Cannot delete document during indexing.')
+ raise DocumentIndexingError("Cannot rename document during indexing.")
return document
@@ -999,51 +977,43 @@ class WebsiteDocumentSyncApi(DocumentResource):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
- raise NotFound('Dataset not found.')
+ raise NotFound("Dataset not found.")
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
- raise NotFound('Document not found.')
+ raise NotFound("Document not found.")
if document.tenant_id != current_user.current_tenant_id:
- raise Forbidden('No permission.')
- if document.data_source_type != 'website_crawl':
- raise ValueError('Document is not a website document.')
+ raise Forbidden("No permission.")
+ if document.data_source_type != "website_crawl":
+ raise ValueError("Document is not a website document.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# sync document
DocumentService.sync_website_document(dataset_id, document)
- return {'result': 'success'}, 200
+ return {"result": "success"}, 200
-api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
-api.add_resource(DatasetDocumentListApi,
- '/datasets/<uuid:dataset_id>/documents')
-api.add_resource(DatasetInitApi,
- '/datasets/init')
-api.add_resource(DocumentIndexingEstimateApi,
- '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate')
-api.add_resource(DocumentBatchIndexingEstimateApi,
- '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate')
-api.add_resource(DocumentBatchIndexingStatusApi,
- '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status')
-api.add_resource(DocumentIndexingStatusApi,
- '/datasets/<uuid:dataset_id>/documents/