diff --git a/api/.idea/vcs.xml b/api/.idea/vcs.xml index eaa7c25c60..b7af618884 100644 --- a/api/.idea/vcs.xml +++ b/api/.idea/vcs.xml @@ -12,5 +12,6 @@ + diff --git a/api/Dockerfile b/api/Dockerfile index 06a6f43631..cca6488679 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -5,6 +5,10 @@ WORKDIR /app/api # Install Poetry ENV POETRY_VERSION=1.8.3 + +# if you located in China, you can use aliyun mirror to speed up +# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/ + RUN pip install --no-cache-dir poetry==${POETRY_VERSION} # Configure Poetry @@ -16,6 +20,9 @@ ENV POETRY_REQUESTS_TIMEOUT=15 FROM base AS packages +# if you located in China, you can use aliyun mirror to speed up +# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources + RUN apt-get update \ && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev @@ -43,6 +50,8 @@ WORKDIR /app/api RUN apt-get update \ && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ + # if you located in China, you can use aliyun mirror to speed up + # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security @@ -56,7 +65,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV} ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data -RUN python -c "import nltk; nltk.download('punkt')" +RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" # Copy source code COPY . 
/app/api/ diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index b28b04f643..8b13789179 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,3 +1 @@ - - diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index b2b9d8d496..eb7c1464d3 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('console', __name__, url_prefix='/console/api') +bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) # Import other controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 028be5de54..a4ceec2662 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv('ADMIN_API_KEY'): - raise Unauthorized('API key is invalid.') + if not os.getenv("ADMIN_API_KEY"): + raise Unauthorized("API key is invalid.") - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - if os.getenv('ADMIN_API_KEY') != auth_token: - raise Unauthorized('API key is invalid.') + if os.getenv("ADMIN_API_KEY") != auth_token: + raise Unauthorized("API key is invalid.") return view(*args, **kwargs) @@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource): @admin_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('desc', type=str, location='json') - parser.add_argument('copyright', type=str, location='json') - parser.add_argument('privacy_policy', type=str, location='json') - parser.add_argument('custom_disclaimer', type=str, location='json') - parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('position', type=int, required=True, nullable=False, location='json') + parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("desc", type=str, location="json") + parser.add_argument("copyright", type=str, location="json") + parser.add_argument("privacy_policy", type=str, location="json") + parser.add_argument("custom_disclaimer", type=str, location="json") + parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args['app_id']).first() + app = App.query.filter(App.id == args["app_id"]).first() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') site = app.site if not site: - desc = args['desc'] if args['desc'] else '' - copy_right = args['copyright'] if args['copyright'] else '' - 
privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = args["desc"] if args["desc"] else "" + copy_right = args["copyright"] if args["copyright"] else "" + privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" + custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" else: - desc = site.description if site.description else \ - args['desc'] if args['desc'] else '' - copy_right = site.copyright if site.copyright else \ - args['copyright'] if args['copyright'] else '' - privacy_policy = site.privacy_policy if site.privacy_policy else \ - args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ - args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = site.description if site.description else args["desc"] if args["desc"] else "" + copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" + privacy_policy = ( + site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" + ) + custom_disclaimer = ( + site.custom_disclaimer + if site.custom_disclaimer + else args["custom_disclaimer"] + if args["custom_disclaimer"] + else "" + ) - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if not recommended_app: recommended_app = RecommendedApp( @@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args['language'], - category=args['category'], - position=args['position'] + language=args["language"], + category=args["category"], + position=args["position"], ) db.session.add(recommended_app) @@ -93,21 
+97,21 @@ class InsertExploreAppListApi(Resource): app.is_public = True db.session.commit() - return {'result': 'success'}, 201 + return {"result": "success"}, 201 else: recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args['language'] - recommended_app.category = args['category'] - recommended_app.position = args['position'] + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] app.is_public = True db.session.commit() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): @@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() if not recommended_app: - return {'result': 'success'}, 204 + return {"result": "success"}, 204 app = App.query.filter(App.id == recommended_app.app_id).first() if app: app.is_public = False installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id ).all() for installed_app in installed_apps: @@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource): db.session.delete(recommended_app) db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') -api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') +api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") +api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py 
index 324b831175..3f5e1adca2 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,26 +14,21 @@ from .setup import setup_required from .wraps import account_initialization_required api_key_fields = { - 'id': fields.String, - 'type': fields.String, - 'token': fields.String, - 'last_used_at': TimestampField, - 'created_at': TimestampField + "id": fields.String, + "type": fields.String, + "token": fields.String, + "last_used_at": TimestampField, + "created_at": TimestampField, } -api_key_list = { - 'data': fields.List(fields.Nested(api_key_fields), attribute="items") -} +api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} def _get_resource(resource_id, tenant_id, resource_model): - resource = resource_model.query.filter_by( - id=resource_id, tenant_id=tenant_id - ).first() + resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource is None: - flask_restful.abort( - 404, message=f"{resource_model.__name__} not found.") + flask_restful.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_list) def get(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). 
\ - all() + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .all() + ) return {"items": keys} @marshal_with(api_key_fields) def post(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource): def delete(self, resource_id, api_key_id): resource_id = str(resource_id) api_key_id = str(api_key_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). 
\ - first() + key = ( + db.session.query(ApiToken) + .filter( + getattr(ApiToken, self.resource_id_field) == resource_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' - token_prefix = 'app-' + resource_id_field = "app_id" + token_prefix = "app-" class AppApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' + resource_id_field = "app_id" class DatasetApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' - token_prefix = 'ds-' + resource_id_field = "dataset_id" + token_prefix = "ds-" class 
DatasetApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' + resource_id_field = "dataset_id" -api.add_resource(AppApiKeyListResource, '/apps//api-keys') -api.add_resource(AppApiKeyResource, - '/apps//api-keys/') -api.add_resource(DatasetApiKeyListResource, - '/datasets//api-keys') -api.add_resource(DatasetApiKeyResource, - '/datasets//api-keys/') +api.add_resource(AppApiKeyListResource, "/apps//api-keys") +api.add_resource(AppApiKeyResource, "/apps//api-keys/") +api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") +api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index fa2b3807e8..e7346bdf1d 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ class AdvancedPromptTemplateList(Resource): - @setup_required @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, location='args') - parser.add_argument('model_mode', type=str, required=True, location='args') - parser.add_argument('has_context', type=str, required=False, default='true', location='args') - parser.add_argument('model_name', type=str, required=True, location='args') + parser.add_argument("app_mode", type=str, required=True, location="args") + parser.add_argument("model_mode", type=str, required=True, 
location="args") + parser.add_argument("has_context", type=str, required=False, default="true", location="args") + parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) -api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file + +api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index aee367276c..51899da705 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -18,15 +18,12 @@ class AgentLogApi(Resource): def get(self, app_model): """Get agent logs""" parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='args') - parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') + parser.add_argument("message_id", type=uuid_value, required=True, location="args") + parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") args = parser.parse_args() - return AgentService.get_agent_logs( - app_model, - args['conversation_id'], - args['message_id'] - ) - -api.add_resource(AgentLogApi, '/apps//agent/logs') \ No newline at end of file + return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + + +api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index bc15919a99..1ea1c82679 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action): if 
not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') - parser.add_argument('embedding_provider_name', required=True, type=str, location='json') - parser.add_argument('embedding_model_name', required=True, type=str, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") + parser.add_argument("embedding_provider_name", required=True, type=str, location="json") + parser.add_argument("embedding_model_name", required=True, type=str, location="json") args = parser.parse_args() - if action == 'enable': + if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) - elif action == 'disable': + elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) else: - raise ValueError('Unsupported annotation reply action') + raise ValueError("Unsupported annotation reply action") return result, 200 @@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource): annotation_setting_id = str(annotation_setting_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) 
cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = "" + if job_status == "error": + app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationListApi(Resource): @@ -109,18 +105,18 @@ class AnnotationListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - keyword = request.args.get('keyword', default=None, type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + keyword = request.args.get("keyword", default=None, type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { - 'data': marshal(annotation_list, annotation_fields), - 'has_more': len(annotation_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_list, annotation_fields), + "has_more": len(annotation_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 @@ -135,9 +131,7 @@ class AnnotationExportApi(Resource): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = { - 'data': marshal(annotation_list, annotation_fields) - } + response = {"data": marshal(annotation_list, annotation_fields)} return response, 200 @@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource): 
@setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): if not current_user.is_editor: @@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource): app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): if not current_user.is_editor: @@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return annotation @@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class 
AnnotationBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) @@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + error_msg = "" + if job_status == "error": + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationHitHistoryListApi(Resource): @@ 
-251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) annotation_id = str(annotation_id) - annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, - page, limit) + annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( + app_id, annotation_id, page, limit + ) response = { - 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), - 'has_more': len(annotation_hit_history_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), + "has_more": len(annotation_hit_history_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response -api.add_resource(AnnotationReplyActionApi, '/apps//annotation-reply/') -api.add_resource(AnnotationReplyActionStatusApi, - '/apps//annotation-reply//status/') -api.add_resource(AnnotationListApi, '/apps//annotations') -api.add_resource(AnnotationExportApi, '/apps//annotations/export') -api.add_resource(AnnotationUpdateDeleteApi, '/apps//annotations/') -api.add_resource(AnnotationBatchImportApi, '/apps//annotations/batch-import') -api.add_resource(AnnotationBatchImportStatusApi, '/apps//annotations/batch-import-status/') -api.add_resource(AnnotationHitHistoryListApi, '/apps//annotations//hit-histories') -api.add_resource(AppAnnotationSettingDetailApi, '/apps//annotation-setting') -api.add_resource(AppAnnotationSettingUpdateApi, '/apps//annotation-settings/') +api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") +api.add_resource( + AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" +) 
+api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationExportApi, "/apps//annotations/export") +api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") +api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") +api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") +api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") +api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") +api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8651597fd7..cc9c8b31cb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -18,27 +18,35 @@ from libs.login import login_required from services.app_dsl_service import AppDslService from services.app_service import AppService -ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] class AppListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): """Get app list""" + def uuid_list(value): try: - return [str(uuid.UUID(v)) for v in value.split(',')] + return [str(uuid.UUID(v)) for v in value.split(",")] except ValueError: abort(400, message="Invalid UUID format in tag_ids.") + parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) - parser.add_argument('name', type=str, location='args', required=False) - parser.add_argument('tag_ids', 
type=uuid_list, location='args', required=False) + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "mode", + type=str, + choices=["chat", "workflow", "agent-chat", "channel", "all"], + default="all", + location="args", + required=False, + ) + parser.add_argument("name", type=str, location="args", required=False) + parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) args = parser.parse_args() @@ -46,7 +54,7 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) if not app_pagination: - return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} + return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return marshal(app_pagination, app_pagination_fields) @@ -54,23 +62,23 @@ class AppListApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Create app""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + parser.add_argument("icon_type", type=str, location="json") + 
parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if 'mode' not in args or args['mode'] is None: + if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() @@ -84,7 +92,7 @@ class AppImportApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -92,19 +100,16 @@ class AppImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=args['data'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user ) return app, 201 @@ -115,7 +120,7 @@ class 
AppImportFromUrlApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app from url""" # The role of the current user in the ta table must be admin, owner, or editor @@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app_from_url( - tenant_id=current_user.current_tenant_id, - url=args['url'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user ) return app, 201 class AppApi(Resource): - @setup_required @login_required @account_initialization_required @@ -165,14 +166,14 @@ class AppApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - 
parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('max_active_requests', type=int, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() @@ -193,7 +194,7 @@ class AppApi(Resource): app_service = AppService() app_service.delete_app(app_model) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppCopyApi(Resource): @@ -209,19 +210,16 @@ class AppCopyApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=data, - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user ) return app, 201 @@ -240,12 +238,10 @@ class AppExportApi(Resource): # Add include_secret params parser = reqparse.RequestParser() - 
parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return { - "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) - } + return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} class AppNameApi(Resource): @@ -258,13 +254,13 @@ class AppNameApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get('name')) + app_model = app_service.update_app_name(app_model, args.get("name")) return app_model @@ -279,14 +275,14 @@ class AppIconApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) + app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) return app_model @@ -301,13 +297,13 @@ class AppSiteStatus(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - 
parser.add_argument('enable_site', type=bool, required=True, location='json') + parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) + app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) return app_model @@ -322,13 +318,13 @@ class AppApiStatus(Resource): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_api', type=bool, required=True, location='json') + parser.add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) + app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) return app_model @@ -339,9 +335,7 @@ class AppTraceApi(Resource): @account_initialization_required def get(self, app_id): """Get app trace""" - app_trace_config = OpsTraceManager.get_app_tracing_config( - app_id=app_id - ) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) return app_trace_config @@ -353,27 +347,27 @@ class AppTraceApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('enabled', type=bool, required=True, location='json') - parser.add_argument('tracing_provider', type=str, required=True, location='json') + parser.add_argument("enabled", type=bool, required=True, location="json") + parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args['enabled'], - tracing_provider=args['tracing_provider'], + enabled=args["enabled"], + 
tracing_provider=args["tracing_provider"], ) return {"result": "success"} -api.add_resource(AppListApi, '/apps') -api.add_resource(AppImportApi, '/apps/import') -api.add_resource(AppImportFromUrlApi, '/apps/import/url') -api.add_resource(AppApi, '/apps/<uuid:app_id>') -api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy') -api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export') -api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name') -api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon') -api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable') -api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable') -api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace') +api.add_resource(AppListApi, "/apps") +api.add_resource(AppImportApi, "/apps/import") +api.add_resource(AppImportFromUrlApi, "/apps/import/url") +api.add_resource(AppApi, "/apps/<uuid:app_id>") +api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") +api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") +api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name") +api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon") +api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable") +api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable") +api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 1de08afa4e..437a6a7b38 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): - file = request.files['file'] + file = request.files["file"] try: response = AudioService.transcript_asr( @@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", 
type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( - 'voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - text=text, - message_id=message_id, - voice=voice - ) + response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -145,12 +145,12 @@ class TextModesApi(Resource): def get(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('language', type=str, required=True, location='args') + parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( 
tenant_id=app_model.tenant_id, - language=args['language'], + language=args["language"], ) return response @@ -179,6 +179,6 @@ class TextModesApi(Resource): raise InternalServerError() -api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text') -api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio') -api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices') +api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text") +api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio") +api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 61582536fd..6fe52ec28a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -35,33 +35,28 @@ from services.app_generate_service import AppGenerateService # define completion message api for user class CompletionMessageApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", 
"streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -97,7 +92,7 @@ class CompletionMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatMessageApi(Resource): @@ -107,27 +102,23 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + 
parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -163,10 +154,10 @@ class ChatMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages') -api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop') -api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages') -api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop') +api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages") +api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop") +api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages") +api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 995264541f..753a6be20c 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat class CompletionConversationApi(Resource): - @setup_required @login_required 
@account_initialization_required @@ -36,24 +35,23 @@ class CompletionConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args['keyword']: - query = query.join( - Message, Message.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( - Message.query.ilike('%{}%'.format(args['keyword'])), - Message.answer.ilike('%{}%'.format(args['keyword'])) + Message.query.ilike("%{}%".format(args["keyword"])), + Message.answer.ilike("%{}%".format(args["keyword"])), ) ) @@ -61,8 +59,8 @@ class 
CompletionConversationApi(Resource): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -70,8 +68,8 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -79,29 +77,25 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class CompletionConversationDetailApi(Resource): - @setup_required @login_required 
@account_initialization_required @@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ChatConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -146,22 +142,28 @@ class ChatConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') - parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", 
type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() subquery = ( db.session.query( - Conversation.id.label('conversation_id'), - EndUser.session_id.label('from_end_user_session_id') + Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") ) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() @@ -169,28 +171,31 @@ class ChatConversationApi(Resource): query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args['keyword']: - keyword_filter = '%{}%'.format(args['keyword']) - query = query.join( - Message, Message.conversation_id == Conversation.id, - ).join( - subquery, subquery.c.conversation_id == Conversation.id - ).filter( - or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter) - ), + if args["keyword"]: + keyword_filter = "%{}%".format(args["keyword"]) + query = ( + query.join( + Message, + Message.conversation_id == Conversation.id, + ) + .join(subquery, subquery.c.conversation_id == Conversation.id) + .filter( + or_( + Message.query.ilike(keyword_filter), + Message.answer.ilike(keyword_filter), + Conversation.name.ilike(keyword_filter), + 
Conversation.introduction.ilike(keyword_filter), + subquery.c.from_end_user_session_id.ilike(keyword_filter), + ), + ) ) account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -198,8 +203,8 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -207,50 +212,46 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) - if args['message_count_gte'] and args['message_count_gte'] >= 1: + if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) .join(Message, Message.conversation_id == Conversation.id) 
.group_by(Conversation.id) - .having(func.count(Message.id) >= args['message_count_gte']) + .having(func.count(Message.id) >= args["message_count_gte"]) ) if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - match args['sort_by']: - case 'created_at': + match args["sort_by"]: + case "created_at": query = query.order_by(Conversation.created_at.asc()) - case '-created_at': + case "-created_at": query = query.order_by(Conversation.created_at.desc()) - case 'updated_at': + case "updated_at": query = query.order_by(Conversation.updated_at.asc()) - case '-updated_at': + case "-updated_at": query = query.order_by(Conversation.updated_at.desc()) case _: query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class ChatConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations') -api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>') -api.add_resource(ChatConversationApi, 
'/apps/<uuid:app_id>/chat-conversations') -api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>') +api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations") +api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") +api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations") +api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") def _get_conversation(app_model, conversation_id): - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index aa0722ea35..23b234dac9 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource): @marshal_with(paginated_conversation_variable_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', type=str, location='args') + parser.add_argument("conversation_id", type=str, location="args") args = parser.parse_args() stmt = ( @@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource): .where(ConversationVariable.app_id == app_model.id) .order_by(ConversationVariable.created_at) ) - if args['conversation_id']: - stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + if args["conversation_id"]: + stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) else: - raise ValueError('conversation_id is required') + raise ValueError("conversation_id is required") # NOTE: This is a temporary solution to avoid performance 
issues. page = 1 @@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource): rows = session.scalars(stmt).all() return { - 'page': page, - 'limit': page_size, - 'total': len(rows), - 'has_more': False, - 'data': [ + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ { - 'created_at': row.created_at, - 'updated_at': row.updated_at, + "created_at": row.created_at, + "updated_at": row.updated_at, **row.to_variable().model_dump(), } for row in rows @@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource): } -api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') +api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index f6feed1221..33d30c2051 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -2,116 +2,120 @@ from libs.exception import BaseHTTPException class AppNotFoundError(BaseHTTPException): - error_code = 'app_not_found' + error_code = "app_not_found" description = "App not found." code = 404 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted Model Provider has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted Model Provider has been exhausted. 
" + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." 
code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class DraftWorkflowNotExist(BaseHTTPException): - error_code = 'draft_workflow_not_exist' + error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." code = 400 class DraftWorkflowNotSync(BaseHTTPException): - error_code = 'draft_workflow_not_sync' + error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." code = 400 class TracingConfigNotExist(BaseHTTPException): - error_code = 'trace_config_not_exist' + error_code = "trace_config_not_exist" description = "Trace config not exist." code = 400 class TracingConfigIsExist(BaseHTTPException): - error_code = 'trace_config_is_exist' + error_code = "trace_config_is_exist" description = "Trace config is exist." code = 400 class TracingConfigCheckError(BaseHTTPException): - error_code = 'trace_config_check_error' + error_code = "trace_config_check_error" description = "Invalid Credentials." 
code = 400 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 6803775e20..3d1e6b7a37 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -24,21 +24,21 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('instruction', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json') - parser.add_argument('no_variable', type=bool, required=True, default=False, location='json') + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() account = current_user - PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, - instruction=args['instruction'], - model_config=args['model_config'], - no_variable=args['no_variable'], - rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=args["no_variable"], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -52,4 +52,4 @@ class RuleGenerateApi(Resource): return rules -api.add_resource(RuleGenerateApi, '/rule-generate') +api.add_resource(RuleGenerateApi, "/rule-generate") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 056415f19a..fe06201982 100644 --- 
a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -33,9 +33,9 @@ from services.message_service import MessageService class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_detail_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_detail_fields)), } @setup_required @@ -45,55 +45,69 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = db.session.query(Conversation).filter( - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") - if args['first_id']: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() + if args["first_id"]: + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .first() + ) if not first_message: raise NotFound("First message not found") - history_messages 
= db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) has_more = False - if len(history_messages) == args['limit']: + if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=args['limit'], - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) class MessageFeedbackApi(Resource): @@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource): @get_app_model def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, 
location='json') - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("message_id", required=True, type=uuid_value, location="json") + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args['message_id']) + message_id = str(args["message_id"]) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") feedback = message.admin_feedback - if not args['rating'] and feedback: + if not args["rating"] and feedback: db.session.delete(feedback) - elif args['rating'] and feedback: - feedback.rating = args['rating'] - elif not args['rating'] and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args['rating'], - from_source='admin', - from_account_id=current_user.id + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, ) db.session.add(feedback) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class MessageAnnotationApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @get_app_model @marshal_with(annotation_fields) def post(self, app_model): @@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - 
parser.add_argument('message_id', required=False, type=uuid_value, location='json') - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') - parser.add_argument('annotation_reply', required=False, type=dict, location='json') + parser.add_argument("message_id", required=False, type=uuid_value, location="json") + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) @@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app_model.id - ).count() + count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {'count': count} + return {"count": count} class MessageSuggestedQuestionApi(Resource): @@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - message_id=message_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} class MessageApi(Resource): @@ -221,10 +227,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter( - 
Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") @@ -232,9 +235,9 @@ class MessageApi(Resource): return message -api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') -api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') -api.add_resource(MessageFeedbackApi, '/apps//feedbacks') -api.add_resource(MessageAnnotationApi, '/apps//annotations') -api.add_resource(MessageAnnotationCountApi, '/apps//annotations/count') -api.add_resource(MessageApi, '/apps//messages/', endpoint='console_message') +api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") +api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") +api.add_resource(MessageFeedbackApi, "/apps//feedbacks") +api.add_resource(MessageAnnotationApi, "/apps//annotations") +api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") +api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index c8df879a29..702afe9864 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -19,19 +19,15 @@ from services.app_model_config_service import AppModelConfigService class ModelConfigResource(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): - """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - config=request.json, - app_mode=AppMode.value_of(app_model.mode) + 
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( @@ -41,15 +37,15 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app_model.app_model_config_id - ).first() + original_app_model_config: AppModelConfig = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} masked_parameter_map = {} tool_map = {} - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue @@ -66,7 +62,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) except Exception as e: continue @@ -79,18 +75,18 @@ class ModelConfigResource(Resource): parameters = {} masked_parameter = {} - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = 
f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: tool_runtime = tool_map[key] else: @@ -108,7 +104,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() @@ -116,15 +112,17 @@ class ModelConfigResource(Resource): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - + for masked_key, masked_value in masked_parameter_map[key].items(): - if masked_key in agent_tool_entity.tool_parameters and \ - agent_tool_entity.tool_parameters[masked_key] == masked_value: + if ( + masked_key in agent_tool_entity.tool_parameters + and agent_tool_entity.tool_parameters[masked_key] == masked_value + ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) @@ -135,12 +133,9 @@ class ModelConfigResource(Resource): app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send( - app_model, - app_model_config=new_app_model_config - ) + app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index c0cf7b9e33..374bd2b815 100644 --- 
a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def get(self, app_id): parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - trace_config = OpsService.get_tracing_app_config( - app_id=app_id, tracing_provider=args['tracing_provider'] - ) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: return {"has_not_configured": True} return trace_config @@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource): def post(self, app_id): """Create a new trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.create_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigIsExist() - if result.get('error'): + if result.get("error"): raise TracingConfigCheckError() return result except Exception as e: @@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource): def patch(self, app_id): """Update an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, 
location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.update_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigNotExist() @@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource): def delete(self, app_id): """Delete an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - result = OpsService.delete_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'] - ) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() return {"result": "success"} @@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource): raise e -api.add_resource(TraceAppConfigApi, '/apps//trace-config') +api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7db58c048a..d903a2609f 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -15,23 +15,23 @@ from models.model import Site def parse_app_site_args(): parser = reqparse.RequestParser() - parser.add_argument('title', type=str, required=False, location='json') - parser.add_argument('icon_type', type=str, required=False, location='json') - parser.add_argument('icon', type=str, required=False, location='json') - parser.add_argument('icon_background', type=str, required=False, location='json') - 
parser.add_argument('description', type=str, required=False, location='json') - parser.add_argument('default_language', type=supported_language, required=False, location='json') - parser.add_argument('chat_color_theme', type=str, required=False, location='json') - parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') - parser.add_argument('customize_domain', type=str, required=False, location='json') - parser.add_argument('copyright', type=str, required=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, location='json') - parser.add_argument('custom_disclaimer', type=str, required=False, location='json') - parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], - required=False, - location='json') - parser.add_argument('prompt_public', type=bool, required=False, location='json') - parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') + parser.add_argument("title", type=str, required=False, location="json") + parser.add_argument("icon_type", type=str, required=False, location="json") + parser.add_argument("icon", type=str, required=False, location="json") + parser.add_argument("icon_background", type=str, required=False, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("default_language", type=supported_language, required=False, location="json") + parser.add_argument("chat_color_theme", type=str, required=False, location="json") + parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + parser.add_argument("customize_domain", type=str, required=False, location="json") + parser.add_argument("copyright", type=str, required=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, location="json") + parser.add_argument("custom_disclaimer", type=str, required=False, location="json") 
+ parser.add_argument( + "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + ) + parser.add_argument("prompt_public", type=bool, required=False, location="json") + parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") return parser.parse_args() @@ -48,26 +48,24 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site). \ - filter(Site.app_id == app_model.id). \ - one_or_404() + site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ - 'title', - 'icon_type', - 'icon', - 'icon_background', - 'description', - 'default_language', - 'chat_color_theme', - 'chat_color_theme_inverted', - 'customize_domain', - 'copyright', - 'privacy_policy', - 'custom_disclaimer', - 'customize_token_strategy', - 'prompt_public', - 'show_workflow_steps' + "title", + "icon_type", + "icon", + "icon_background", + "description", + "default_language", + "chat_color_theme", + "chat_color_theme_inverted", + "customize_domain", + "copyright", + "privacy_policy", + "custom_disclaimer", + "customize_token_strategy", + "prompt_public", + "show_workflow_steps", ]: value = args.get(attr_name) if value is not None: @@ -79,7 +77,6 @@ class AppSite(Resource): class AppSiteAccessTokenReset(Resource): - @setup_required @login_required @account_initialization_required @@ -101,5 +98,5 @@ class AppSiteAccessTokenReset(Resource): return site -api.add_resource(AppSite, '/apps//site') -api.add_resource(AppSiteAccessTokenReset, '/apps//site/access-token-reset') +api.add_resource(AppSite, "/apps//site") +api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index b882ffef34..bf65efeae4 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -17,7 +17,6 @@ from 
models.model import AppMode class DailyConversationStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = 
end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'conversation_count': i.conversation_count - }) + response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTerminalsStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -86,54 +79,49 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= 
:start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTokenCostStatistic(Resource): @@ -145,58 +133,53 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, sum(total_price) as total_price FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 
'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - 'total_price': i.total_price, - 'currency': 'USD' - }) + response_data.append( + {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageSessionInteractionStatistic(Resource): @@ -208,8 +191,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), 
location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -218,30 +201,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count FROM conversations c JOIN messages m ON c.id = m.conversation_id WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and c.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and c.created_at < :end" + arg_dict["end"] = end_datetime_utc sql_query += """ GROUP BY m.conversation_id) subquery @@ -250,18 +233,15 @@ GROUP BY date ORDER BY date""" response_data = [] - + with db.engine.begin() as 
conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class UserSatisfactionRateStatistic(Resource): @@ -273,57 +253,57 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count FROM messages m LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' WHERE m.app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and m.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - 
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and m.created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), - }) + response_data.append( + { + "date": str(i.date), + "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), + } + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageResponseTimeStatistic(Resource): @@ -335,56 +315,51 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(provider_response_latency) as latency FROM messages WHERE app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) 
utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'latency': round(i.latency * 1000, 4) - }) + response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class TokensPerSecondStatistic(Resource): @@ -396,63 +371,58 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + 
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second FROM messages -WHERE app_id = :app_id''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} +WHERE app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'tps': 
round(i.tokens_per_second, 4) - }) + response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(DailyConversationStatistic, '/apps//statistics/daily-conversations') -api.add_resource(DailyTerminalsStatistic, '/apps//statistics/daily-end-users') -api.add_resource(DailyTokenCostStatistic, '/apps//statistics/token-costs') -api.add_resource(AverageSessionInteractionStatistic, '/apps//statistics/average-session-interactions') -api.add_resource(UserSatisfactionRateStatistic, '/apps//statistics/user-satisfaction-rate') -api.add_resource(AverageResponseTimeStatistic, '/apps//statistics/average-response-time') -api.add_resource(TokensPerSecondStatistic, '/apps//statistics/tokens-per-second') +api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") +api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") +api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") +api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") +api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") +api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") +api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 42cf1b7981..e44820f634 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - - content_type = request.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = request.headers.get("Content-Type", "") 
+ + if "application/json" in content_type: parser = reqparse.RequestParser() - parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - parser.add_argument('hash', type=str, required=False, location='json') + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") # TODO: set this to required=True after frontend is updated - parser.add_argument('environment_variables', type=list, required=False, location='json') - parser.add_argument('conversation_variables', type=list, required=False, location='json') + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() - elif 'text/plain' in content_type: + elif "text/plain" in content_type: try: - data = json.loads(request.data.decode('utf-8')) - if 'graph' not in data or 'features' not in data: - raise ValueError('graph or features not found in data') + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") - if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): - raise ValueError('graph or features is not a dict') + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") args = { - 'graph': data.get('graph'), - 'features': data.get('features'), - 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables'), - 'conversation_variables': data.get('conversation_variables'), + "graph": data.get("graph"), + 
"features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), } except json.JSONDecodeError: - return {'message': 'Invalid JSON data'}, 400 + return {"message": "Invalid JSON data"}, 400 else: abort(415) workflow_service = WorkflowService() try: - environment_variables_list = args.get('environment_variables') or [] + environment_variables_list = args.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = args.get('conversation_variables') or [] + conversation_variables_list = args.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args['graph'], - features=args['features'], - unique_hash=args.get('hash'), + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource): return { "result": "success", "hash": workflow.unique_hash, - "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), } @@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") args = parser.parse_args() workflow = AppDslService.import_and_overwrite_workflow( - app_model=app_model, - data=args['data'], - account=current_user + app_model=app_model, data=args["data"], 
account=current_user ) return workflow @@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') - parser.add_argument('query', type=str, required=True, location='json', default='') - parser.add_argument('files', type=list, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument("inputs", type=dict, location="json") + parser.add_argument("query", type=str, required=True, location="json", default="") + parser.add_argument("files", type=list, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class AdvancedChatDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, 
args=args, streaming=True ) return helper.compact_generate_response(response) @@ -240,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -265,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class DraftWorkflowRunApi(Resource): @setup_required @login_required @@ -277,19 +265,15 @@ class DraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -312,12 +296,10 @@ class WorkflowTaskStopApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if 
not current_user.is_editor: raise Forbidden() - + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return { - "result": "success" - } + return {"result": "success"} class DraftWorkflowNodeRunApi(Resource): @@ -333,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() workflow_service = WorkflowService() workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, - node_id=node_id, - user_inputs=args.get('inputs'), - account=current_user + app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user ) return workflow_node_execution class PublishedWorkflowApi(Resource): - @setup_required @login_required @account_initialization_required @@ -363,7 +341,7 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) @@ -382,14 +360,11 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + workflow_service = WorkflowService() workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) - return { - "result": "success", - "created_at": TimestampField().format(workflow.created_at) - } + return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} class DefaultBlockConfigsApi(Resource): @@ 
-404,7 +379,7 @@ class DefaultBlockConfigsApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -422,24 +397,21 @@ class DefaultBlockConfigApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('q', type=str, location='args') + parser.add_argument("q", type=str, location="args") args = parser.parse_args() filters = None - if args.get('q'): + if args.get("q"): try: - filters = json.loads(args.get('q')) + filters = json.loads(args.get("q")) except json.JSONDecodeError: - raise ValueError('Invalid filters') + raise ValueError("Invalid filters") # Get default block configs workflow_service = WorkflowService() - return workflow_service.get_default_block_config( - node_type=block_type, - filters=filters - ) + return workflow_service.get_default_block_config(node_type=block_type, filters=filters) class ConvertToWorkflowApi(Resource): @@ -456,41 +428,43 @@ class ConvertToWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if request.data: parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + 
parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") args = parser.parse_args() else: args = {} # convert to workflow mode workflow_service = WorkflowService() - new_app_model = workflow_service.convert_to_workflow( - app_model=app_model, - account=current_user, - args=args - ) + new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) # return app id return { - 'new_app_id': new_app_model.id, + "new_app_id": new_app_model.id, } -api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') -api.add_resource(DraftWorkflowImportApi, '/apps//workflows/draft/import') -api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') -api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') -api.add_resource(WorkflowTaskStopApi, '/apps//workflow-runs/tasks//stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') -api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps//advanced-chat/workflows/draft/iteration/nodes//run') -api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps//workflows/draft/iteration/nodes//run') -api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') -api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' - '/') -api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') +api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") +api.add_resource(DraftWorkflowImportApi, "/apps//workflows/draft/import") +api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") +api.add_resource(DraftWorkflowRunApi, "/apps//workflows/draft/run") +api.add_resource(WorkflowTaskStopApi, "/apps//workflow-runs/tasks//stop") 
+api.add_resource(DraftWorkflowNodeRunApi, "/apps//workflows/draft/nodes//run") +api.add_resource( + AdvancedChatDraftRunIterationNodeApi, + "/apps//advanced-chat/workflows/draft/iteration/nodes//run", +) +api.add_resource( + WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run" +) +api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") +api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") +api.add_resource( + DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs" "/" +) +api.add_resource(ConvertToWorkflowApi, "/apps//convert-to-workflow") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6d1709ed8e..dc962409cc 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource): Get workflow app logs """ parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() # get paginate workflow app logs workflow_app_service = WorkflowAppService() workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( - app_model=app_model, - args=args + app_model=app_model, args=args ) return workflow_app_log_pagination 
-api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') +api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 35d982e37c..a055d03deb 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource): Get advanced chat app workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) return result @@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource): Get workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) return result @@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource): workflow_run_service = WorkflowRunService() 
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) - return { - 'data': node_executions - } + return {"data": node_executions} -api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps//advanced-chat/workflow-runs') -api.add_resource(WorkflowRunListApi, '/apps//workflow-runs') -api.add_resource(WorkflowRunDetailApi, '/apps//workflow-runs/') -api.add_resource(WorkflowRunNodeExecutionListApi, '/apps//workflow-runs//node-executions') +api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") +api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") +api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") +api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 1d7dc395ff..db2f683589 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + 
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'runs': i.runs - }) + response_data.append({"date": str(i.date), "runs": i.runs}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTerminalsStatistic(Resource): @setup_required @@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", 
type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = 
conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTokenCostStatistic(Resource): @setup_required @@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) as token_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - 
if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - }) + response_data.append( + { + "date": str(i.date), + "token_count": i.token_count, + } + ) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowAverageAppInteractionStatistic(Resource): @setup_required @@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource): GROUP BY date, c.created_by) sub GROUP BY sub.date """ - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = 
pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') - arg_dict['start'] = start_datetime_utc + sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") + arg_dict["start"] = start_datetime_utc else: - sql_query = sql_query.replace('{{start}}', '') + sql_query = sql_query.replace("{{start}}", "") - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') - arg_dict['end'] = end_datetime_utc + sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") + arg_dict["end"] = end_datetime_utc else: - sql_query = sql_query.replace('{{end}}', '') + sql_query = sql_query.replace("{{end}}", "") response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(WorkflowDailyRunsStatistic, '/apps//workflow/statistics/daily-conversations') -api.add_resource(WorkflowDailyTerminalsStatistic, '/apps//workflow/statistics/daily-terminals') 
-api.add_resource(WorkflowDailyTokenCostStatistic, '/apps//workflow/statistics/token-costs') -api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps//workflow/statistics/average-app-interactions') + +api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") +api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") +api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") +api.add_resource( + WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" +) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index d61ab6d6ae..5e0a4bc814 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,24 +8,23 @@ from libs.login import current_user from models.model import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - if not kwargs.get('app_id'): - raise ValueError('missing app_id in path parameters') + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") - app_id = kwargs.get('app_id') + app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs['app_id'] + del kwargs["app_id"] - app_model = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app_model: raise AppNotFoundError() @@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *, mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode 
is not in the supported list: {mode_values}") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model return view_func(*args, **kwargs) + return decorated_view if view is None: diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8efb55cdb6..8ba6b53e7e 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -17,60 +17,61 @@ from services.account_service import RegisterService class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') - parser.add_argument('email', type=email, required=False, nullable=True, location='args') - parser.add_argument('token', type=str, required=True, nullable=False, location='args') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") + parser.add_argument("email", type=email, required=False, nullable=True, location="args") + parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args['workspace_id'] - reg_email = args['email'] - token = args['token'] + workspaceId = args["workspace_id"] + reg_email = args["email"] + token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} + return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') - parser.add_argument('email', type=email, required=False, nullable=True, location='json') - parser.add_argument('token', type=str, required=True, nullable=False, 
location='json') - parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') - parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, - location='json') - parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("email", type=email, required=False, nullable=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" + ) + parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) + RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - account = invitation['account'] - account.name = args['name'] + account = invitation["account"] + account.name = args["name"] # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() # encrypt password with salt - password_hashed = hash_password(args['password'], salt) + password_hashed = hash_password(args["password"], salt) base64_password_hashed = 
base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ActivateCheckApi, '/activate/check') -api.add_resource(ActivateApi, '/activate') +api.add_resource(ActivateCheckApi, "/activate/check") +api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index f79b93b74f..50db6eebc1 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) if data_source_api_key_bindings: return { - 'sources': [{ - 'id': data_source_api_key_binding.id, - 'category': data_source_api_key_binding.category, - 'provider': data_source_api_key_binding.provider, - 'disabled': data_source_api_key_binding.disabled, - 'created_at': int(data_source_api_key_binding.created_at.timestamp()), - 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), - } - for data_source_api_key_binding in - data_source_api_key_bindings] + "sources": [ + { + "id": data_source_api_key_binding.id, + "category": data_source_api_key_binding.category, + "provider": data_source_api_key_binding.provider, + "disabled": data_source_api_key_binding.disabled, + "created_at": 
int(data_source_api_key_binding.created_at.timestamp()), + "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), + } + for data_source_api_key_binding in data_source_api_key_bindings + ] } - return {'sources': []} + return {"sources": []} class ApiKeyAuthDataSourceBinding(Resource): @@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ApiKeyAuthDataSourceBindingDelete(Resource): @@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') -api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') -api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') +api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") +api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") 
+api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 45cfa9d7eb..fd31e5ccc3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -17,13 +17,13 @@ from ..wraps import account_initialization_required def get_oauth_providers(): with current_app.app_context(): - notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') + notion_oauth = NotionOAuth( + client_id=dify_config.NOTION_CLIENT_ID, + client_secret=dify_config.NOTION_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", + ) - OAUTH_PROVIDERS = { - 'notion': notion_oauth - } + OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -37,18 +37,16 @@ class OAuthDataSource(Resource): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if dify_config.NOTION_INTEGRATION_TYPE == 'internal': + return {"error": "Invalid provider"}, 400 + if dify_config.NOTION_INTEGRATION_TYPE == "internal": internal_secret = dify_config.NOTION_INTERNAL_SECRET if not internal_secret: - return {'error': 'Internal secret is not set'}, + return ({"error": "Internal secret is not set"},) oauth_provider.save_internal_access_token(internal_secret) - return { 'data': '' } + return {"data": ""} else: auth_url = oauth_provider.get_authorization_url() - return { 'data': auth_url }, 200 - - + return {"data": auth_url}, 200 class OAuthDataSourceCallback(Resource): @@ -57,18 +55,18 @@ class OAuthDataSourceCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if 
not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') - elif 'error' in request.args: - error = request.args.get('error') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") + elif "error" in request.args: + error = request.args.get("error") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') - + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") + class OAuthDataSourceBinding(Resource): def get(self, provider: str): @@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + ) + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class OAuthDataSourceSync(Resource): @@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not 
oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') -api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/') -api.add_resource(OAuthDataSourceSync, '/oauth/data-source///sync') +api.add_resource(OAuthDataSource, "/oauth/data-source/") +api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") +api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") +api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 53dab3298f..ea23e097d0 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException class ApiKeyAuthFailedError(BaseHTTPException): - error_code = 'auth_failed' + error_code = "auth_failed" description = "{message}" code = 500 class InvalidEmailError(BaseHTTPException): - error_code = 'invalid_email' + error_code = "invalid_email" description = "The email address is not valid." code = 400 class PasswordMismatchError(BaseHTTPException): - error_code = 'password_mismatch' + error_code = "password_mismatch" description = "The passwords do not match." 
code = 400 class InvalidTokenError(BaseHTTPException): - error_code = 'invalid_or_expired_token' + error_code = "invalid_or_expired_token" description = "The token is invalid or has expired." code = 400 class PasswordResetRateLimitExceededError(BaseHTTPException): - error_code = 'password_reset_rate_limit_exceeded' + error_code = "password_reset_rate_limit_exceeded" description = "Password reset rate limit exceeded. Try again later." code = 429 - diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d78be770ab..0b01a4906a 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError class ForgotPasswordSendEmailApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('email', type=str, required=True, location='json') + parser.add_argument("email", type=str, required=True, location="json") args = parser.parse_args() - email = args['email'] + email = args["email"] if not email_validate(email): raise InvalidEmailError() @@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: - return {'is_valid': False, 'email': None} - return {'is_valid': True, 'email': reset_data.get('email')} + return {"is_valid": False, "email": None} + return {"is_valid": True, "email": reset_data.get("email")} class ForgotPasswordResetApi(Resource): - @setup_required def post(self): parser = 
reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - new_password = args['new_password'] - password_confirm = args['password_confirm'] + new_password = args["new_password"] + password_confirm = args["password_confirm"] if str(new_password).strip() != str(password_confirm).strip(): raise PasswordMismatchError() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: @@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get('email')).first() + account = Account.query.filter_by(email=reset_data.get("email")).first() account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') -api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') -api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git 
a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index c135ece67e..62837af2b9 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,37 +20,39 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() - parser.add_argument('email', type=email, required=True, location='json') - parser.add_argument('password', type=valid_password, required=True, location='json') - parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() # todo: Verify the recaptcha try: - account = AccountService.authenticate(args['email'], args['password']) + account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError as e: - return {'code': 'unauthorized', 'message': str(e)}, 401 + return {"code": "unauthorized", "message": str(e)}, 401 # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token = AccountService.login(account, ip_address=get_remote_ip(request)) - return {'result': 'success', 'data': token} + return {"result": "success", "data": token} class LogoutApi(Resource): - @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get('Authorization', '').split(' ')[1] + token = request.headers.get("Authorization", "").split(" 
")[1] AccountService.logout(account=account, token=token) flask_login.logout_user() - return {'result': 'success'} + return {"result": "success"} class ResetPasswordApi(Resource): @@ -80,11 +82,11 @@ class ResetPasswordApi(Resource): # 'subject': 'Reset your Dify password', # 'html': """ #

Dear User,

- #

The Dify team has generated a new password for you, details as follows:

+ #

The Dify team has generated a new password for you, details as follows:

#

{new_password}

#

Please change your password to log in as soon as possible.

#

Regards,

- #

The Dify Team

+ #

The Dify Team

# """ # } @@ -101,8 +103,8 @@ class ResetPasswordApi(Resource): # # handle error # pass - return {'result': 'success'} + return {"result": "success"} -api.add_resource(LoginApi, '/login') -api.add_resource(LogoutApi, '/logout') +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a651bfe7b..ae1b49f3ec 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -25,7 +25,7 @@ def get_oauth_providers(): github_oauth = GitHubOAuth( client_id=dify_config.GITHUB_CLIENT_ID, client_secret=dify_config.GITHUB_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", ) if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: google_oauth = None @@ -33,10 +33,10 @@ def get_oauth_providers(): google_oauth = GoogleOAuth( client_id=dify_config.GOOGLE_CLIENT_ID, client_secret=dify_config.GOOGLE_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS @@ -47,7 +47,7 @@ class OAuthLogin(Resource): oauth_provider = OAUTH_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 auth_url = oauth_provider.get_authorization_url() return redirect(auth_url) @@ -59,20 +59,20 @@ class OAuthCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 - code 
= request.args.get('code') + code = request.args.get("code") try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.HTTPError as e: - logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') - return {'error': 'OAuth process failed'}, 400 + logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {"error": "OAuth process failed"}, 400 account = _generate_account(provider, user_info) # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - return {'error': 'Account is banned or closed.'}, 403 + return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -83,7 +83,7 @@ class OAuthCallback(Resource): token = AccountService.login(account, ip_address=get_remote_ip(request)) - return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: # Create account - account_name = user_info.name if user_info.name else 'Dify' + account_name = user_info.name if user_info.name else "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) @@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -api.add_resource(OAuthLogin, '/oauth/login/') -api.add_resource(OAuthCallback, '/oauth/authorize/') +api.add_resource(OAuthLogin, "/oauth/login/") +api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/billing/billing.py 
b/api/controllers/console/billing/billing.py index 72a6129efa..9a1d914869 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -9,28 +9,24 @@ from services.billing_service import BillingService class Subscription(Resource): - @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): - parser = reqparse.RequestParser() - parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) - parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) + parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args['plan'], - args['interval'], - current_user.email, - current_user.current_tenant_id) + return BillingService.get_subscription( + args["plan"], args["interval"], current_user.email, current_user.current_tenant_id + ) class Invoices(Resource): - @setup_required @login_required @account_initialization_required @@ -40,5 +36,5 @@ class Invoices(Resource): return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) -api.add_resource(Subscription, '/billing/subscription') -api.add_resource(Invoices, '/billing/invoices') +api.add_resource(Subscription, "/billing/subscription") +api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0ca0f0a856..0e1acab946 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task class DataSourceApi(Resource): - 
@setup_required @login_required @account_initialization_required @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceOauthBinding).filter( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.disabled == False - ).all() + data_source_integrates = ( + db.session.query(DataSourceOauthBinding) + .filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False, + ) + .all() + ) - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] @@ -44,26 +47,30 @@ class DataSourceApi(Resource): existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) if existing_integrates: for existing_integrate in list(existing_integrates): - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'disabled': existing_integrate.disabled, - 'source_info': existing_integrate.source_info, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "disabled": existing_integrate.disabled, + "source_info": existing_integrate.source_info, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'source_info': None, - 'is_bound': False, - 'disabled': None, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) - return {'data': integrate_data}, 200 + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "source_info": None, + "is_bound": False, + "disabled": 
None, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data}, 200 @setup_required @login_required @@ -71,92 +78,82 @@ class DataSourceApi(Resource): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by( - id=binding_id - ).first() + data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() if data_source_binding is None: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") # enable binding - if action == 'enable': + if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is not disabled.') + raise ValueError("Data source is not disabled.") # disable binding - if action == 'disable': + if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is disabled.') - return {'result': 'success'}, 200 + raise ValueError("Data source is disabled.") + return {"result": "success"}, 200 class DataSourceNotionListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): - dataset_id = request.args.get('dataset_id', default=None, type=str) + dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] # import notion in the exist dataset if dataset_id: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') - if dataset.data_source_type != 
'notion_import': - raise ValueError('Dataset is not notion type.') + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") documents = Document.query.filter_by( dataset_id=dataset_id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) + exist_page_ids.append(data_source_info["notion_page_id"]) # get all authorized pages data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, - provider='notion', - disabled=False + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False ).all() if not data_source_bindings: - return { - 'notion_info': [] - }, 200 + return {"notion_info": []}, 200 pre_import_info_list = [] for data_source_binding in data_source_bindings: source_info = data_source_binding.source_info - pages = source_info['pages'] + pages = source_info["pages"] # Filter out already bound pages for page in pages: - if page['page_id'] in exist_page_ids: - page['is_bound'] = True + if page["page_id"] in exist_page_ids: + page["is_bound"] = True else: - page['is_bound'] = False + page["is_bound"] = False pre_import_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, } pre_import_info_list.append(pre_import_info) - return { - 'notion_info': pre_import_info_list - }, 200 + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): - @setup_required 
@login_required @account_initialization_required @@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) text_docs = extractor.extract() - return { - 'content': "\n".join([doc.page_content for doc in text_docs]) - }, 200 + return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + 
parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = args['notion_info_list'] + notion_info_list = args["notion_info_list"] extract_settings = [] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + ) return response, 200 class DataSourceNotionDatasetSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource): return 200 -api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates//') -api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, - 
'/notion/workspaces//pages///preview', - '/datasets/notion-indexing-estimate') -api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets//notion/sync') -api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets//documents//notion/sync') +api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") +api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") +api.add_resource( + DataSourceNotionApi, + "/notion/workspaces//pages///preview", + "/datasets/notion-indexing-estimate", +) +api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets//notion/sync") +api.add_resource( + DataSourceNotionDocumentSyncApi, "/datasets//documents//notion/sync" +) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index b9a1c25154..d369730594 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name def _validate_description_length(description): if len(description) > 400: - raise ValueError('Description cannot exceed 400 characters.') + raise ValueError("Description cannot exceed 400 characters.") return description class DatasetListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - ids = request.args.getlist('ids') - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", 
default=20, type=int) + ids = request.args.getlist("ids") + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: - datasets, total = DatasetService.get_datasets(page, limit, provider, - current_user.current_tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids + ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -77,28 +72,22 @@ class DatasetListApi(Resource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True + item["embedding_available"] = True - if item.get('permission') == 'partial_members': - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) - item.update({'partial_member_list': part_users_list}) + if item.get("permission") == "partial_members": + part_users_list = 
DatasetPermissionService.get_dataset_partial_member_list(item["id"]) + item.update({"partial_member_list": part_users_list}) else: - item.update({'partial_member_list': []}) + item.update({"partial_member_list": []}) - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @setup_required @@ -106,13 +95,21 @@ class DatasetListApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. 
Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -122,9 +119,9 @@ class DatasetListApi(Resource): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -142,42 +139,36 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission( - dataset, current_user) + DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = marshal(dataset, dataset_detail_fields) - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in 
embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data['indexing_technique'] == 'high_quality': + if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: - data['embedding_available'] = True + data["embedding_available"] = True else: - data['embedding_available'] = False + data["embedding_available"] = False else: - data['embedding_available'] = True + data["embedding_available"] = True - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) return data, 200 @@ -191,42 +182,49 @@ class DatasetApi(Resource): raise NotFound("Dataset not found.") parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('description', - location='json', store_missing=False, - type=_validate_description_length) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.' 
- ) - parser.add_argument('embedding_model', type=str, - location='json', help='Invalid embedding model.') - parser.add_argument('embedding_model_provider', type=str, - location='json', help='Invalid embedding model provider.') - parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') - parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') + parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." 
+ ) + parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") args = parser.parse_args() data = request.get_json() # check embedding model setting - if data.get('indexing_technique') == 'high_quality': - DatasetService.check_embedding_model_setting(dataset.tenant_id, - data.get('embedding_model_provider'), - data.get('embedding_model') - ) + if data.get("indexing_technique") == "high_quality": + DatasetService.check_embedding_model_setting( + dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get('permission'), data.get('partial_member_list') + current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset( - dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -234,16 +232,19 @@ class DatasetApi(Resource): result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get('partial_member_list') and data.get('permission') == 'partial_members': + if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get('partial_member_list') + tenant_id, dataset_id_str, data.get("partial_member_list") ) # clear partial member list when permission is only_me or all_team_members - elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM: + elif ( + data.get("permission") == DatasetPermissionEnum.ONLY_ME + or data.get("permission") == 
DatasetPermissionEnum.ALL_TEAM + ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - result_data.update({'partial_member_list': partial_member_list}) + result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -260,12 +261,13 @@ class DatasetApi(Resource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() + class DatasetUseCheckApi(Resource): @setup_required @login_required @@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) - return {'is_using': dataset_is_using}, 200 + return {"is_using": dataset_is_using}, 200 + class DatasetQueryApi(Resource): - @setup_required @login_required @account_initialization_required @@ -292,51 +294,53 @@ class DatasetQueryApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries( - dataset_id=dataset.id, - page=page, - per_page=limit - ) + dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) response = { - 'data': marshal(dataset_queries, dataset_query_detail_fields), - 'has_more': len(dataset_queries) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(dataset_queries, dataset_query_detail_fields), + 
"has_more": len(dataset_queries) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 class DatasetIndexingEstimateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') + parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument( + "indexing_technique", + type=str, + required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + location="json", + ) + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args['info_list']['data_source_type'] == 'upload_file': - file_ids = args['info_list']['file_info_list']['file_ids'] - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(file_ids) - 
).all() + if args["info_list"]["data_source_type"] == "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .all() + ) if file_details is None: raise NotFound("File not found.") @@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=args['doc_form'] + datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'notion_import': - notion_info_list = args['info_list']['notion_info_list'] + elif args["info_list"]["data_source_type"] == "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'website_crawl': - website_info_list = args['info_list']['website_info_list'] - for url in website_info_list['urls']: + elif args["info_list"]["data_source_type"] == "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = 
ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": website_info_list['provider'], - "job_id": website_info_list['job_id'], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], "url": url, "tenant_id": current_user.current_tenant_id, - "mode": 'crawl', - "only_main_content": website_info_list['only_main_content'] + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + args["dataset_id"], + args["indexing_technique"], + ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource): class DatasetRelatedAppListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource): if app_model: related_apps.append(app_model) - return { - 'data': related_apps, - 'total': len(related_apps) - }, 200 + return {"data": related_apps, "total": len(related_apps)}, 200 class DatasetIndexingStatusApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .all() + ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data 
class DatasetApiKeyApi(Resource): max_keys = 10 - token_prefix = 'dataset-' - resource_type = 'dataset' + token_prefix = "dataset-" + resource_type = "dataset" @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - all() + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .all() + ) return {"items": keys} @setup_required @@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource): class DatasetApiDeleteApi(Resource): - resource_type = 'dataset' + resource_type = "dataset" @setup_required @login_required @@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id). 
\ - first() + key = ( + db.session.query(ApiToken) + .filter( + ApiToken.tenant_id == current_user.current_tenant_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DatasetApiBaseUrlApi(Resource): @@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource): @account_initialization_required def get(self): return { - 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + "api_base_url": ( + dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") + ) + + "/v1" } @@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource): def get(self): vector_type = dify_config.VECTOR_STORE match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.PGVECTOR + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, 
RetrievalMethod.HYBRID_SEARCH.value, @@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.PGVECTOR + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") - class DatasetErrorDocs(Resource): @setup_required @login_required @@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource): raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return { - 'data': [marshal(item, document_status_fields) for item in results], - 'total': len(results) - }, 200 + return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 class DatasetPermissionUserListApi(Resource): @@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource): partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) 
return { - 'data': partial_members_list, + "data": partial_members_list, }, 200 -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') -api.add_resource(DatasetUseCheckApi, '/datasets//use-check') -api.add_resource(DatasetQueryApi, '/datasets//queries') -api.add_resource(DatasetErrorDocs, '/datasets//error-docs') -api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') -api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') -api.add_resource(DatasetIndexingStatusApi, '/datasets//indexing-status') -api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') -api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') -api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') -api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') -api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') -api.add_resource(DatasetPermissionUserListApi, '/datasets//permission-part-users') +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DatasetUseCheckApi, "/datasets//use-check") +api.add_resource(DatasetQueryApi, "/datasets//queries") +api.add_resource(DatasetErrorDocs, "/datasets//error-docs") +api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") +api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") +api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") +api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") +api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") +api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") +api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") +api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") +api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") diff --git a/api/controllers/console/datasets/datasets_document.py 
b/api/controllers/console/datasets/datasets_document.py index 976b97660a..7d0b9f0460 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -57,7 +57,7 @@ class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -67,17 +67,17 @@ class DocumentResource(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') + raise Forbidden("No permission.") return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -87,7 +87,7 @@ class DocumentResource(Resource): documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") return documents @@ -99,11 +99,11 @@ class GetProcessRuleApi(Resource): def get(self): req_data = request.args - document_id = req_data.get('document_id') + document_id = req_data.get("document_id") # get default rules - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -111,7 +111,7 @@ class GetProcessRuleApi(Resource): dataset = 
DatasetService.get_dataset(document.dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -119,19 +119,18 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == document.dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return { - 'mode': mode, - 'rules': rules - } + return {"mode": mode, "rules": rules} class DatasetDocumentListApi(Resource): @@ -140,49 +139,48 @@ class DatasetDocumentListApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - sort = request.args.get('sort', default='-created_at', type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + sort = request.args.get("sort", default="-created_at", type=str) # "yes", "true", "t", "y", "1" convert to True, while others convert to False. 
try: - fetch = string_to_bool(request.args.get('fetch', default='false')) + fetch = string_to_bool(request.args.get("fetch", default="false")) except (ArgumentTypeError, ValueError, Exception) as e: fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith('-'): + if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == 'hit_count': - sub_query = db.select(DocumentSegment.document_id, - db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ - .group_by(DocumentSegment.document_id) \ + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == 'created_at': + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": query = query.order_by( sort_logic(Document.created_at), sort_logic(Document.position), @@ -193,48 +191,47 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate( - 
page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { - 'data': data, - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response - documents_and_batch_fields = { - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } + documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} @setup_required @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise 
NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: @@ -246,21 +243,22 @@ class DatasetDocumentListApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('data_source', type=dict, required=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, location='json') - parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("data_source", type=dict, required=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, location="json") + parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() - if not dataset.indexing_technique 
and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") # validate args DocumentService.document_create_args_validate(args) @@ -274,51 +272,53 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return { - 'documents': documents, - 'batch': batch - } + return {"documents": documents, "batch": batch} class DatasetInitApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, - nullable=False, location='json') - parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", + type=str, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + required=True, + nullable=False, + location="json", + ) + parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + 
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - if args['indexing_technique'] == 'high_quality': + if args["indexing_technique"] == "high_quality": try: model_manager = ModelManager() model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -327,9 +327,7 @@ class DatasetInitApi(Resource): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, - document_data=args, - account=current_user + tenant_id=current_user.current_tenant_id, document_data=args, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -338,17 +336,12 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = { - 'dataset': dataset, - 'documents': documents, - 'batch': batch - } + response = {"dataset": dataset, "documents": documents, "batch": batch} return response class DocumentIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -357,50 +350,49 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == 
document.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: - raise NotFound('File not found.') + raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + [extract_setting], + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -410,7 +402,6 @@ class DocumentIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -418,13 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: return response data_process_rule = documents[0].dataset_process_rule @@ -432,82 +417,83 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] info_list.append(file_id) # format document notion info - elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info: + elif ( + data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info + ): pages = [] - page = { - 'page_id': data_source_info['notion_page_id'], - 'type': data_source_info['type'] - } + page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} pages.append(page) - notion_info = { - 'workspace_id': data_source_info['notion_workspace_id'], - 'pages': pages - 
} + notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == 'upload_file': - file_id = data_source_info['upload_file_id'] - file_detail = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id == file_id - ).first() + if document.data_source_type == "upload_file": + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .first() + ) if file_detail is None: raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == 'notion_import': + elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], - "tenant_id": current_user.current_tenant_id + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) - elif document.data_source_type == 'website_crawl': + elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], - "url": data_source_info['url'], + "provider": 
data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], "tenant_id": current_user.current_tenant_id, - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -516,7 +502,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -526,24 +511,24 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DocumentIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -552,25 +537,24 @@ class DocumentIndexingStatusApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query \ - .filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 
're_segment') \ - .count() - total_segments = DocumentSegment.query \ - .filter(DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): - METADATA_CHOICES = {'all', 'only', 'without'} + METADATA_CHOICES = {"all", "only", "without"} @setup_required @login_required @@ -580,77 +564,73 @@ class DocumentDetailApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get('metadata', 'all') + metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: - raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == 'only': - response = { - 'id': document.id, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata - } - elif metadata == 'without': + if metadata == "only": + response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} + elif metadata == "without": process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': 
document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + 
"hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, } else: process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if 
document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, } return response, 200 @@ -671,7 +651,7 @@ class DocumentProcessingApi(DocumentResource): if action == "pause": if document.indexing_status != "indexing": - raise InvalidActionError('Document not in indexing state.') + raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -680,7 +660,7 @@ class DocumentProcessingApi(DocumentResource): elif action == "resume": if document.indexing_status not in ["paused", "error"]: - raise InvalidActionError('Document not in paused or error state.') + raise InvalidActionError("Document not in paused or error state.") document.paused_by = None document.paused_at = None @@ -689,7 +669,7 @@ class DocumentProcessingApi(DocumentResource): else: raise InvalidActionError() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentDeleteApi(DocumentResource): @@ -710,9 +690,9 @@ class DocumentDeleteApi(DocumentResource): try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 204 + return {"result": "success"}, 
204 class DocumentMetadataApi(DocumentResource): @@ -726,26 +706,26 @@ class DocumentMetadataApi(DocumentResource): req_data = request.get_json() - doc_type = req_data.get('doc_type') - doc_metadata = req_data.get('doc_metadata') + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() if doc_type is None or doc_metadata is None: - raise ValueError('Both doc_type and doc_metadata must be provided.') + raise ValueError("Both doc_type and doc_metadata must be provided.") if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') + raise ValueError("Invalid doc_type.") if not isinstance(doc_metadata, dict): - raise ValueError('doc_metadata must be a dictionary.') + raise ValueError("doc_metadata must be a dictionary.") metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] document.doc_metadata = {} - if doc_type == 'others': + if doc_type == "others": document.doc_metadata = doc_metadata else: for key, value_type in metadata_schema.items(): @@ -757,14 +737,14 @@ class DocumentMetadataApi(DocumentResource): document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + return {"result": "success", "message": "Document metadata updated."}, 200 class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -784,14 +764,14 @@ class DocumentStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) - indexing_cache_key = 'document_{}_indexing'.format(document.id) + 
indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") if action == "enable": if document.enabled: - raise InvalidActionError('Document already enabled.') + raise InvalidActionError("Document already enabled.") document.enabled = True document.disabled_at = None @@ -804,13 +784,13 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": - if not document.completed_at or document.indexing_status != 'completed': - raise InvalidActionError('Document is not completed.') + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError("Document is not completed.") if not document.enabled: - raise InvalidActionError('Document already disabled.') + raise InvalidActionError("Document already disabled.") document.enabled = False document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -823,11 +803,11 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "archive": if document.archived: - raise InvalidActionError('Document already archived.') + raise InvalidActionError("Document already archived.") document.archived = True document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -841,10 +821,10 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "un_archive": if not document.archived: - raise InvalidActionError('Document is not archived.') + raise InvalidActionError("Document is not archived.") document.archived = False document.archived_at = None @@ 
-857,13 +837,12 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() class DocumentPauseApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -874,7 +853,7 @@ class DocumentPauseApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) @@ -890,9 +869,9 @@ class DocumentPauseApi(DocumentResource): # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot pause completed document.') + raise DocumentIndexingError("Cannot pause completed document.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRecoverApi(DocumentResource): @@ -905,7 +884,7 @@ class DocumentRecoverApi(DocumentResource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -919,9 +898,9 @@ class DocumentRecoverApi(DocumentResource): # pause document DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Document is not in paused status.') + raise DocumentIndexingError("Document is not in paused status.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRetryApi(DocumentResource): @@ -932,15 +911,14 @@ class DocumentRetryApi(DocumentResource): """retry document.""" parser = reqparse.RequestParser() - parser.add_argument('document_ids', type=list, required=True, nullable=False, - location='json') + 
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: - raise NotFound('Dataset not found.') - for document_id in args['document_ids']: + raise NotFound("Dataset not found.") + for document_id in args["document_ids"]: try: document_id = str(document_id) @@ -955,7 +933,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == 'completed': + if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: @@ -964,7 +942,7 @@ class DocumentRetryApi(DocumentResource): # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRenameApi(DocumentResource): @@ -979,13 +957,13 @@ class DocumentRenameApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_operator_permission(current_user, dataset) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - document = DocumentService.rename_document(dataset_id, document_id, args['name']) + document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") return document @@ -999,51 +977,43 @@ class WebsiteDocumentSyncApi(DocumentResource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise 
NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') - if document.data_source_type != 'website_crawl': - raise ValueError('Document is not a website document.') + raise Forbidden("No permission.") + if document.data_source_type != "website_crawl": + raise ValueError("Document is not a website document.") # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(GetProcessRuleApi, '/datasets/process-rule') -api.add_resource(DatasetDocumentListApi, - '/datasets//documents') -api.add_resource(DatasetInitApi, - '/datasets/init') -api.add_resource(DocumentIndexingEstimateApi, - '/datasets//documents//indexing-estimate') -api.add_resource(DocumentBatchIndexingEstimateApi, - '/datasets//batch//indexing-estimate') -api.add_resource(DocumentBatchIndexingStatusApi, - '/datasets//batch//indexing-status') -api.add_resource(DocumentIndexingStatusApi, - '/datasets//documents//indexing-status') -api.add_resource(DocumentDetailApi, - '/datasets//documents/') -api.add_resource(DocumentProcessingApi, - '/datasets//documents//processing/') -api.add_resource(DocumentDeleteApi, - '/datasets//documents/') -api.add_resource(DocumentMetadataApi, - '/datasets//documents//metadata') -api.add_resource(DocumentStatusApi, - '/datasets//documents//status/') -api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') -api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') -api.add_resource(DocumentRetryApi, '/datasets//retry') 
-api.add_resource(DocumentRenameApi, - '/datasets//documents//rename') +api.add_resource(GetProcessRuleApi, "/datasets/process-rule") +api.add_resource(DatasetDocumentListApi, "/datasets//documents") +api.add_resource(DatasetInitApi, "/datasets/init") +api.add_resource( + DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" +) +api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") +api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") +api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource( + DocumentProcessingApi, "/datasets//documents//processing/" +) +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") +api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") +api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") +api.add_resource(DocumentRetryApi, "/datasets//retry") +api.add_resource(DocumentRenameApi, "/datasets//documents//rename") -api.add_resource(WebsiteDocumentSyncApi, '/datasets//documents//website-sync') +api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a4210d5a0c..2405649387 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ 
-50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument('last_id', type=str, default=None, location='args') - parser.add_argument('limit', type=int, default=20, location='args') - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('hit_count_gte', type=int, - default=None, location='args') - parser.add_argument('enabled', type=str, default='all', location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("last_id", type=str, default=None, location="args") + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("hit_count_gte", type=int, default=None, location="args") + parser.add_argument("enabled", type=str, default="all", location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - last_id = args['last_id'] - limit = min(args['limit'], 100) - status_list = args['status'] - hit_count_gte = args['hit_count_gte'] - keyword = args['keyword'] + last_id = args["last_id"] + limit = min(args["limit"], 100) + status_list = args["status"] + hit_count_gte = args["hit_count_gte"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if last_id is not None: last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: - query = query.filter( - DocumentSegment.position > last_segment.position) + query = 
query.filter(DocumentSegment.position > last_segment.position) else: - return {'data': [], 'has_more': False, 'limit': limit}, 200 + return {"data": [], "has_more": False, "limit": limit}, 200 if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource): query = query.filter(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args['enabled'].lower() != 'all': - if args['enabled'].lower() == 'true': + if args["enabled"].lower() != "all": + if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) - elif args['enabled'].lower() == 'false': + elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) total = query.count() @@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource): segments = segments[:-1] return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'has_more': has_more, - 'limit': limit, - 'total': total + "data": marshal(segments, segment_fields), + "doc_form": document.doc_form, + "has_more": has_more, + "limit": limit, + "total": total, }, 200 @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource): 
DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable or disable function is not allowed') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable or disable function is not allowed") - document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not 
None: raise InvalidActionError("Segment is being indexed, please try again later") @@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): enable_segment_to_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": if not segment.enabled: raise InvalidActionError("Segment is already disabled.") @@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource): disable_segment_from_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() @@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if not current_user.is_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. 
Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.create_segment(args, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 
'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() 
SegmentService.segment_create_args_validate(args, document) segment = SegmentService.update_segment(args, segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @login_required @@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin or owner if not current_user.is_editor: raise Forbidden() @@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + 
@cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource): df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - if document.doc_form == 'qa_model': - data = {'content': row[0], 'answer': row[1]} + if document.doc_form == "qa_model": + data = {"content": row[0], "answer": row[1]} else: - data = {'content': row[0]} + data = {"content": row[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_create_segment_to_index_task.delay( + str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: 
- return {'error': str(e)}, 500 - return { - 'job_id': job_id, - 'job_status': 'waiting' - }, 200 + return {"error": str(e)}, 500 + return {"job_id": job_id, "job_status": "waiting"}, 200 @setup_required @login_required @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") - return { - 'job_id': job_id, - 'job_status': cache_result.decode() - }, 200 + return {"job_id": job_id, "job_status": cache_result.decode()}, 200 -api.add_resource(DatasetDocumentSegmentListApi, - '/datasets//documents//segments') -api.add_resource(DatasetDocumentSegmentApi, - '/datasets//segments//') -api.add_resource(DatasetDocumentSegmentAddApi, - '/datasets//documents//segment') -api.add_resource(DatasetDocumentSegmentUpdateApi, - '/datasets//documents//segments/') -api.add_resource(DatasetDocumentSegmentBatchImportApi, - '/datasets//documents//segments/batch_import', - '/datasets/batch_import_status/') +api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") +api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") +api.add_resource( + DatasetDocumentSegmentUpdateApi, + "/datasets//documents//segments/", +) +api.add_resource( + DatasetDocumentSegmentBatchImportApi, + "/datasets//documents//segments/batch_import", + "/datasets/batch_import_status/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 9270b610c2..6a7a3971a8 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 
'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." 
code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class WebsiteCrawlError(BaseHTTPException): - error_code = 'crawl_failed' + error_code = "crawl_failed" description = "{message}" code = 500 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 class IndexingEstimateError(BaseHTTPException): - error_code = 'indexing_estimate_error' + error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 3b2083bcc3..d6a464545e 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000 class FileApi(Resource): - @setup_required @login_required @account_initialization_required @@ -31,23 +30,22 @@ class FileApi(Resource): batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT return { - 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit, - 'image_file_size_limit': image_file_size_limit + "file_size_limit": file_size_limit, + "batch_count_limit": batch_count_limit, + "image_file_size_limit": image_file_size_limit, }, 200 @setup_required @login_required @account_initialization_required @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource='documents') + 
@cloud_edition_billing_resource_check(resource="documents") def post(self): - # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -69,7 +67,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) text = FileService.get_file_preview(file_id) - return {'content': text} + return {"content": text} class FileSupportTypeApi(Resource): @@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - return {'allowed_extensions': allowed_extensions} + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + return {"allowed_extensions": allowed_extensions} -api.add_resource(FileApi, '/files/upload') -api.add_resource(FilePreviewApi, '/files//preview') -api.add_resource(FileSupportTypeApi, '/files/support-type') +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8771bf909e..0b4a7be986 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService class HitTestingApi(Resource): - @setup_required @login_required @account_initialization_required @@ -46,8 +45,8 @@ class HitTestingApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('query', type=str, location='json') - parser.add_argument('retrieval_model', type=dict, required=False, location='json') + 
parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -55,13 +54,13 @@ class HitTestingApi(Resource): try: response = HitTestingService.retrieve( dataset=dataset, - query=args['query'], + query=args["query"], account=current_user, - retrieval_model=args['retrieval_model'], - limit=10 + retrieval_model=args["retrieval_model"], + limit=10, ) - return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() except ProviderTokenNotInitError as ex: @@ -73,7 +72,8 @@ class HitTestingApi(Resource): except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -83,4 +83,4 @@ class HitTestingApi(Resource): raise InternalServerError(str(e)) -api.add_resource(HitTestingApi, '/datasets//hit-testing') +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bbd91256f1..cb54f1aacb 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -9,16 +9,14 @@ from services.website_service import WebsiteService class WebsiteCrawlApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], - required=True, nullable=True, location='json') - parser.add_argument('url', type=str, required=True, nullable=True, location='json') - parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") + parser.add_argument("url", type=str, required=True, nullable=True, location="json") + parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() WebsiteService.document_create_args_validate(args) # crawl url @@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource): @account_initialization_required def get(self, job_id: str): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") args = parser.parse_args() # get crawl status try: - result = WebsiteService.get_crawl_status(job_id, args['provider']) + result = WebsiteService.get_crawl_status(job_id, args["provider"]) except Exception as e: 
raise WebsiteCrawlError(str(e)) return result, 200 -api.add_resource(WebsiteCrawlApi, '/website/crawl') -api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/') +api.add_resource(WebsiteCrawlApi, "/website/crawl") +api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83cc..1c70ea6c59 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException class AlreadySetupError(BaseHTTPException): - error_code = 'already_setup' + error_code = "already_setup" description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." code = 403 class NotSetupError(BaseHTTPException): - error_code = 'not_setup' - description = "Dify has not been initialized and installed yet. " \ - "Please proceed with the initialization and installation process first." + error_code = "not_setup" + description = ( + "Dify has not been initialized and installed yet. " + "Please proceed with the initialization and installation process first." + ) code = 401 + class NotInitValidateError(BaseHTTPException): - error_code = 'not_init_validated' - description = "Init validation has not been completed yet. " \ - "Please proceed with the init validation process first." + error_code = "not_init_validated" + description = ( + "Init validation has not been completed yet. " "Please proceed with the init validation process first." + ) code = 401 + class InitValidateFailedError(BaseHTTPException): - error_code = 'init_validate_failed' + error_code = "init_validate_failed" description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): - error_code = 'account_not_link_tenant' + error_code = "account_not_link_tenant" description = "Account not link tenant." 
code = 403 class AlreadyActivateError(BaseHTTPException): - error_code = 'already_activate' + error_code = "already_activate" description = "Auth Token is invalid or account already activated, please check again." code = 403 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 27cc83042a..71cb060ecc 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=None - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource): app_model = installed_app.app try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", 
None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - voice=voice, - text=text - ) + response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource): raise InternalServerError() -api.add_resource(ChatAudioApi, '/installed-apps//audio-to-text', endpoint='installed_app_audio') -api.add_resource(ChatTextApi, '/installed-apps//text-to-audio', endpoint='installed_app_text') +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") # api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', # endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 869b56e13b..c039e8bca5 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService # define completion api for user class 
CompletionApi(InstalledAppResource): - def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) return helper.compact_generate_response(response) @@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() 
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(InstalledAppResource): @@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args['auto_generate_name'] = False + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') -api.add_resource(CompletionStopApi, 
'/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') -api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ea0fa4e17e..2918024b64 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app @@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() pinned = None - 
if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=current_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.EXPLORE, pinned=pinned, ) @@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app @@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] + app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(InstalledAppResource): - def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/installed-apps//conversations//name', endpoint='installed_app_conversation_rename') -api.add_resource(ConversationListApi, '/installed-apps//conversations', endpoint='installed_app_conversations') 
-api.add_resource(ConversationApi, '/installed-apps//conversations/', endpoint='installed_app_conversation') -api.add_resource(ConversationPinApi, '/installed-apps//conversations//pin', endpoint='installed_app_conversation_pin') -api.add_resource(ConversationUnPinApi, '/installed-apps//conversations//unpin', endpoint='installed_app_conversation_unpin') +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 9c3216ecc8..18221b7797 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Not Completion App" code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "App mode is invalid." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Only support workflow app." 
code = 400 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ec7bbed307..b71078760c 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -21,72 +21,71 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_apps = [ { - 'id': installed_app.id, - 'app': installed_app.app, - 'app_owner_tenant_id': installed_app.app_owner_tenant_id, - 'is_pinned': installed_app.is_pinned, - 'last_used_at': installed_app.last_used_at, - 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id + "id": installed_app.id, + "app": installed_app.app, + "app_owner_tenant_id": installed_app.app_owner_tenant_id, + "is_pinned": installed_app.is_pinned, + "last_used_at": installed_app.last_used_at, + "editable": current_user.role in ["owner", "admin"], + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], - app['last_used_at'] is None, - -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) + installed_apps.sort( + key=lambda app: ( + -app["is_pinned"], + app["last_used_at"] 
is None, + -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, + ) + ) - return {'installed_apps': installed_apps} + return {"installed_apps": installed_apps} @login_required @account_initialization_required - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: - raise NotFound('App not found') + raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter( - App.id == args['app_id'] - ).first() + app = db.session.query(App).filter(App.id == args["app_id"]).first() if app is None: - raise NotFound('App not found') + raise NotFound("App not found") if not app.is_public: - raise Forbidden('You can\'t install a non-public app') + raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() + installed_app = InstalledApp.query.filter( + and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) + ).first() if installed_app is None: # todo: position recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args['app_id'], + app_id=args["app_id"], tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() - 
return {'message': 'App installed successfully'} + return {"message": "App installed successfully"} class InstalledAppApi(InstalledAppResource): @@ -94,30 +93,31 @@ class InstalledAppApi(InstalledAppResource): update and delete an installed app use InstalledAppResource to apply default decorators and get installed_app """ + def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - raise BadRequest('You can\'t uninstall an app owned by the current tenant') + raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) db.session.commit() - return {'result': 'success', 'message': 'App uninstalled successfully'} + return {"result": "success", "message": "App uninstalled successfully"} def patch(self, installed_app): parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) + parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] + if "is_pinned" in args: + installed_app.is_pinned = args["is_pinned"] commit_args = True if commit_args: db.session.commit() - return {'result': 'success', 'message': 'App info updated successfully'} + return {"result": "success", "message": "App info updated successfully"} -api.add_resource(InstalledAppsListApi, '/installed-apps') -api.add_resource(InstalledAppApi, '/installed-apps/') +api.add_resource(InstalledAppsListApi, "/installed-apps") +api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3523a86900..f5eb185172 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', 
required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise 
NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) except MessageNotExistsError: @@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") @@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') -api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') -api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") 
+api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0a168d6306..ad55b04043 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from configs import dify_config @@ -11,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": 
fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + 
"suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/installed-apps//parameters', - endpoint='installed_app_parameters') -api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') +api.add_resource( + AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" +) +api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6e10e2ec92..5daaa1e7c3 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,28 +8,28 @@ from libs.login import login_required from services.recommended_app_service import RecommendedAppService app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 
'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean + "app": fields.Nested(app_fields, attribute="app"), + "app_id": fields.String, + "description": fields.String(attribute="description"), + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "category": fields.String, + "position": fields.Integer, + "is_listed": fields.Boolean, } recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) + "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "categories": fields.List(fields.String), } @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): def get(self): # language args parser = reqparse.RequestParser() - parser.add_argument('language', type=str, location='args') + parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get('language') and args.get('language') in languages: - language_prefix = args.get('language') + if args.get("language") and args.get("language") in languages: + language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: @@ -61,5 +61,5 @@ class RecommendedAppApi(Resource): return RecommendedAppService.get_recommend_app_detail(app_id) -api.add_resource(RecommendedAppListApi, '/explore/apps') -api.add_resource(RecommendedAppApi, '/explore/apps/') +api.add_resource(RecommendedAppListApi, 
"/explore/apps") +api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index cf86b2fee1..a7ccf737a8 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, 
location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, current_user, args['message_id']) + SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(InstalledAppResource): @@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, current_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/installed-apps//saved-messages', endpoint='installed_app_saved_messages') -api.add_resource(SavedMessageApi, '/installed-apps//saved-messages/', endpoint='installed_app_saved_message') +api.add_resource( + SavedMessageListApi, + "/installed-apps//saved-messages", + endpoint="installed_app_saved_messages", +) +api.add_resource( + SavedMessageApi, + "/installed-apps//saved-messages/", + endpoint="installed_app_saved_message", +) diff --git a/api/controllers/console/explore/workflow.py 
b/api/controllers/console/explore/workflow.py index 7c5e211d47..45f99b1db9 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps//workflows/run') -api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps//workflows/tasks//stop') +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" +) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 84890f1b46..3c9317847b 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -14,29 +14,33 @@ def installed_app_required(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - if not kwargs.get('installed_app_id'): - raise ValueError('missing installed_app_id 
in path parameters') + if not kwargs.get("installed_app_id"): + raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get('installed_app_id') + installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs['installed_app_id'] + del kwargs["installed_app_id"] - installed_app = db.session.query(InstalledApp).filter( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter( + InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id + ) + .first() + ) if installed_app is None: - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") return view(installed_app, *args, **kwargs) + return decorated if view: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index fe73bcb985..5d6a8bf152 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService class CodeBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('module', type=str, required=True, location='args') + parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return { - 'module': args['module'], - 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) - } + return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} class APIBasedExtensionAPI(Resource): - @setup_required @login_required 
@account_initialization_required @@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource): @marshal_with(api_based_extension_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, - name=args['name'], - api_endpoint=args['api_endpoint'], - api_key=args['api_key'] + name=args["name"], + api_endpoint=args["api_endpoint"], + api_key=args["api_key"], ) return APIBasedExtensionService.save(extension_data) class APIBasedExtensionDetailAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args['name'] - extension_data_from_db.api_endpoint = args['api_endpoint'] + extension_data_from_db.name = args["name"] + extension_data_from_db.api_endpoint = args["api_endpoint"] - if args['api_key'] != 
HIDDEN_VALUE: - extension_data_from_db.api_key = args['api_key'] + if args["api_key"] != HIDDEN_VALUE: + extension_data_from_db.api_key = args["api_key"] return APIBasedExtensionService.save(extension_data_from_db) @@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') +api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") -api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') +api.add_resource(APIBasedExtensionAPI, "/api-based-extension") +api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 8475cd8488..f0482f749d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record class FeatureApi(Resource): - @setup_required @login_required @account_initialization_required @@ -24,5 +23,5 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(FeatureApi, '/features') -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(FeatureApi, "/features") +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 6feb1003a9..7d3ae677ee 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted class InitValidateAPI(Resource): - def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {"status": "finished"} + 
return {"status": "not_started"} @only_edition_self_hosted def post(self): @@ -29,22 +28,23 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument('password', type=str_len(30), - required=True, location='json') - input_password = parser.parse_args()['password'] + parser.add_argument("password", type=str_len(30), required=True, location="json") + input_password = parser.parse_args()["password"] - if input_password != os.environ.get('INIT_PASSWORD'): - session['is_init_validated'] = False + if input_password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False raise InitValidateFailedError() - - session['is_init_validated'] = True - return {'result': 'success'}, 201 + + session["is_init_validated"] = True + return {"result": "success"}, 201 + def get_init_validate_status(): - if dify_config.EDITION == 'SELF_HOSTED': - if os.environ.get('INIT_PASSWORD'): - return session.get('is_init_validated') or DifySetup.query.first() - + if dify_config.EDITION == "SELF_HOSTED": + if os.environ.get("INIT_PASSWORD"): + return session.get("is_init_validated") or DifySetup.query.first() + return True -api.add_resource(InitValidateAPI, '/init') + +api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 7664ba8c16..cd28cc946e 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -4,14 +4,11 @@ from controllers.console import api class PingApi(Resource): - def get(self): """ For connection health check """ - return { - "result": "pong" - } + return {"result": "pong"} -api.add_resource(PingApi, '/ping') +api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ef7cc6bc03..827695e00f 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted class 
SetupApi(Resource): - def get(self): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() if setup_status: - return { - 'step': 'finished', - 'setup_at': setup_status.setup_at.isoformat() - } - return {'step': 'not_started'} - return {'step': 'finished'} + return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + return {"step": "not_started"} + return {"step": "finished"} @only_edition_self_hosted def post(self): @@ -38,28 +34,22 @@ class SetupApi(Resource): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - + if not get_init_validate_status(): raise NotInitValidateError() parser = reqparse.RequestParser() - parser.add_argument('email', type=email, - required=True, location='json') - parser.add_argument('name', type=str_len( - 30), required=True, location='json') - parser.add_argument('password', type=valid_password, - required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() # setup RegisterService.setup( - email=args['email'], - name=args['name'], - password=args['password'], - ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) ) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 def setup_required(view): @@ -68,7 +58,7 @@ def setup_required(view): # check setup if not get_init_validate_status(): raise NotInitValidateError() - + elif not get_setup_status(): raise NotSetupError() @@ -78,9 +68,10 @@ def setup_required(view): def get_setup_status(): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() else: return True 
-api.add_resource(SetupApi, '/setup') + +api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 004afaa531..7293aeeb34 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -14,19 +14,18 @@ from services.tag_service import TagService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 50 characters.') + raise ValueError("Name must be between 1 to 50 characters.") return name class TagListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get('type', type=str) - keyword = request.args.get('keyword', default=None, type=str) + tag_type = request.args.get("type", type=str) + keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) return tags, 200 @@ -40,28 +39,21 @@ class TagListApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." 
+ ) args = parser.parse_args() tag = TagService.save_tags(args) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': 0 - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 class TagUpdateDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': binding_count - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} return response, 200 @@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource): class TagBindingCreateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', - help='Tag IDs is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." 
+ ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource): class TagBindingDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_id', type=str, nullable=False, required=True, - help='Tag ID is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." 
+ ) args = parser.parse_args() TagService.delete_tag_binding(args) return 200 -api.add_resource(TagListApi, '/tags') -api.add_resource(TagUpdateDeleteApi, '/tags/') -api.add_resource(TagBindingCreateApi, '/tag-bindings/create') -api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') +api.add_resource(TagListApi, "/tags") +api.add_resource(TagUpdateDeleteApi, "/tags/") +api.add_resource(TagBindingCreateApi, "/tag-bindings/create") +api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 1fcf4bdc00..76adbfe6a9 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ - import json import logging @@ -11,42 +10,39 @@ from . import api class VersionApi(Resource): - def get(self): parser = reqparse.RequestParser() - parser.add_argument('current_version', type=str, required=True, location='args') + parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': dify_config.CURRENT_VERSION, - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False, - 'features': { - 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, - 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED - } + "version": dify_config.CURRENT_VERSION, + "release_date": "", + "release_notes": "", + "can_auto_update": False, + "features": { + "can_replace_logo": dify_config.CAN_REPLACE_LOGO, + "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, + }, } if not check_update_url: return result try: - response = requests.get(check_update_url, { - 'current_version': args.get('current_version') - }) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - result['version'] = 
args.get('current_version') + result["version"] = args.get("current_version") return result content = json.loads(response.content) - result['version'] = content['version'] - result['release_date'] = content['releaseDate'] - result['release_notes'] = content['releaseNotes'] - result['can_auto_update'] = content['canAutoUpdate'] + result["version"] = content["version"] + result["release_date"] = content["releaseDate"] + result["release_notes"] = content["releaseNotes"] + result["can_auto_update"] = content["canAutoUpdate"] return result -api.add_resource(VersionApi, '/version') +api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1056d5eb62..dec426128f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr class AccountInitApi(Resource): - @setup_required @login_required def post(self): account = current_user - if account.status == 'active': + if account.status == "active": raise AccountAlreadyInitedError() parser = reqparse.RequestParser() - if dify_config.EDITION == 'CLOUD': - parser.add_argument('invitation_code', type=str, location='json') + if dify_config.EDITION == "CLOUD": + parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') - parser.add_argument('timezone', type=timezone, - required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == 'CLOUD': - if not args['invitation_code']: - raise ValueError('invitation_code is required') + if dify_config.EDITION == "CLOUD": + if not 
args["invitation_code"]: + raise ValueError("invitation_code is required") # check invitation code - invitation_code = db.session.query(InvitationCode).filter( - InvitationCode.code == args['invitation_code'], - InvitationCode.status == 'unused', - ).first() + invitation_code = ( + db.session.query(InvitationCode) + .filter( + InvitationCode.code == args["invitation_code"], + InvitationCode.status == "unused", + ) + .first() + ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = 'used' + invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' - account.status = 'active' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" + account.status = "active" account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class AccountProfileApi(Resource): @@ -90,15 +91,14 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length - if len(args['name']) < 3 or len(args['name']) > 30: - raise ValueError( - "Account name must be between 3 and 30 characters.") + if len(args["name"]) < 3 or len(args["name"]) > 30: + raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args['name']) + updated_account = 
AccountService.update_account(current_user, name=args["name"]) return updated_account @@ -110,10 +110,10 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('avatar', type=str, required=True, location='json') + parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) return updated_account @@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) return updated_account @@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('interface_theme', type=str, choices=[ - 'light', 'dark'], required=True, location='json') + parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) return updated_account @@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - 
parser.add_argument('timezone', type=str, - required=True, location='json') + parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. America/New_York, Asia/Shanghai - if args['timezone'] not in pytz.all_timezones: + if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args['timezone']) + updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) return updated_account @@ -177,20 +174,16 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('password', type=str, - required=False, location='json') - parser.add_argument('new_password', type=str, - required=True, location='json') - parser.add_argument('repeat_new_password', type=str, - required=True, location='json') + parser.add_argument("password", type=str, required=False, location="json") + parser.add_argument("new_password", type=str, required=True, location="json") + parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args['new_password'] != args['repeat_new_password']: + if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() try: - AccountService.update_account_password( - current_user, args['password'], args['new_password']) + AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -199,14 +192,14 @@ class AccountPasswordApi(Resource): class AccountIntegrateApi(Resource): integrate_fields = { - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'link': fields.String + "provider": fields.String, + "created_at": TimestampField, + "is_bound": 
fields.Boolean, + "link": fields.String, } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), + "data": fields.List(fields.Nested(integrate_fields)), } @setup_required @@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter( - AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = ["github", "google"] @@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource): for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) if existing_integrate: - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'link': None - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "link": None, + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'is_bound': False, - 'link': f'{base_url}{oauth_base_path}/{provider}' - }) - - return {'data': integrate_data} - + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "is_bound": False, + "link": f"{base_url}{oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data} # Register API resources -api.add_resource(AccountInitApi, '/account/init') -api.add_resource(AccountProfileApi, '/account/profile') -api.add_resource(AccountNameApi, '/account/name') -api.add_resource(AccountAvatarApi, '/account/avatar') -api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') 
-api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') -api.add_resource(AccountTimezoneApi, '/account/timezone') -api.add_resource(AccountPasswordApi, '/account/password') -api.add_resource(AccountIntegrateApi, '/account/integrates') +api.add_resource(AccountInitApi, "/account/init") +api.add_resource(AccountProfileApi, "/account/profile") +api.add_resource(AccountNameApi, "/account/name") +api.add_resource(AccountAvatarApi, "/account/avatar") +api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") +api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") +api.add_resource(AccountTimezoneApi, "/account/timezone") +api.add_resource(AccountPasswordApi, "/account/password") +api.add_resource(AccountIntegrateApi, "/account/integrates") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 99f55835bc..9e13c7b924 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException class RepeatPasswordNotMatchError(BaseHTTPException): - error_code = 'repeat_password_not_match' + error_code = "repeat_password_not_match" description = "New password and repeat password does not match." code = 400 class CurrentPasswordIncorrectError(BaseHTTPException): - error_code = 'current_password_incorrect' + error_code = "current_password_incorrect" description = "Current password is incorrect." code = 400 class ProviderRequestFailedError(BaseHTTPException): - error_code = 'provider_request_failed' + error_code = "provider_request_failed" description = None code = 400 class InvalidInvitationCodeError(BaseHTTPException): - error_code = 'invalid_invitation_code' + error_code = "invalid_invitation_code" description = "Invalid invitation code." 
code = 400 class AccountAlreadyInitedError(BaseHTTPException): - error_code = 'account_already_inited' + error_code = "account_already_inited" description = "The account has been initialized. Please refresh the page." code = 400 class AccountNotInitializedError(BaseHTTPException): - error_code = 'account_not_initialized' + error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 50514e39f6..771a866624 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + 
credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response @@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials @@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'], + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], config_id=config_id, ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response # Load Balancing Config -api.add_resource(LoadBalancingCredentialsValidateApi, - 
'/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') +api.add_resource( + LoadBalancingCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", +) -api.add_resource(LoadBalancingConfigCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') +api.add_resource( + LoadBalancingConfigCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", +) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 34e9da3841..3e87bebf59 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -23,7 +23,7 @@ class MemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 class MemberInviteEmailApi(Resource): @@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('members') + @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument('emails', type=str, required=True, location='json', action='append') - parser.add_argument('role', type=str, required=True, default='admin', location='json') - parser.add_argument('language', type=str, required=False, location='json') + parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("role", type=str, required=True, default="admin", location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args['emails'] - 
invitee_role = args['role'] - interface_language = args['language'] + invitee_emails = args["emails"] + invitee_role = args["role"] + interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: - token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' - }) + token = RegisterService.invite_new_member( + inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + ) + invitation_results.append( + { + "status": "success", + "email": invitee_email, + "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", + } + ) except AccountAlreadyInTenantError: - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/signin' - }) + invitation_results.append( + {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + ) break except Exception as e: - invitation_results.append({ - 'status': 'failed', - 'email': invitee_email, - 'message': str(e) - }) + invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) return { - 'result': 'success', - 'invitation_results': invitation_results, + "result": "success", + "invitation_results": invitation_results, }, 201 @@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource): try: TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) except services.errors.account.CannotOperateSelfError as e: - return {'code': 
'cannot-operate-self', 'message': str(e)}, 400 + return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: - return {'code': 'forbidden', 'message': str(e)}, 403 + return {"code": "forbidden", "message": str(e)}, 403 except services.errors.account.MemberNotInTenantError as e: - return {'code': 'member-not-found', 'message': str(e)}, 404 + return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class MemberUpdateRoleApi(Resource): @@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource): @account_initialization_required def put(self, member_id): parser = reqparse.RequestParser() - parser.add_argument('role', type=str, required=True, location='json') + parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - new_role = args['role'] + new_role = args["role"] if not TenantAccountRole.is_valid_role(new_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) if not member: @@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource): # todo: 403 - return {'result': 'success'} + return {"result": "success"} class DatasetOperatorMemberListApi(Resource): @@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 -api.add_resource(MemberListApi, '/workspaces/current/members') -api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') -api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/') -api.add_resource(MemberUpdateRoleApi, 
'/workspaces/current/members//update-role') -api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') +api.add_resource(MemberListApi, "/workspaces/current/members") +api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") +api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") +api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") +api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c888159f83..8c38420226 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService class ModelProviderListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -25,21 +24,23 @@ class ModelProviderListApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=False, nullable=True, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=False, + nullable=True, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list( - tenant_id=tenant_id, - model_type=args.get('model_type') - ) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) return jsonable_encoder({"data": provider_list}) class ModelProviderCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource): tenant_id = current_user.current_tenant_id 
model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials( - tenant_id=tenant_id, - provider=provider - ) + credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return { - "credentials": credentials - } + return {"credentials": credentials} class ModelProviderValidateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id @@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource): try: model_provider_service.provider_credentials_validate( - tenant_id=tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderApi(Resource): - @setup_required @login_required @account_initialization_required @@ -103,21 +94,19 @@ class ModelProviderApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() try: model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - credentials=args['credentials'] + 
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 @setup_required @login_required @@ -127,12 +116,9 @@ class ModelProviderApi(Resource): raise Forbidden() model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderIconApi(Resource): @@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource): def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - provider=provider, - icon_type=icon_type, - lang=lang + provider=provider, icon_type=icon_type, lang=lang ) return send_file(io.BytesIO(icon), mimetype=mimetype) class PreferredProviderTypeUpdateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, - choices=['system', 'custom'], location='json') + parser.add_argument( + "preferred_provider_type", + type=str, + required=True, + nullable=False, + choices=["system", "custom"], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, - provider=provider, - preferred_provider_type=args['preferred_provider_type'] + tenant_id=tenant_id, provider=provider, 
preferred_provider_type=args["preferred_provider_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderPaymentCheckoutUrlApi(Resource): @@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if provider != 'anthropic': - raise ValueError(f'provider name {provider} is invalid') + if provider != "anthropic": + raise ValueError(f"provider name {provider} is invalid") BillingService.is_tenant_owner_or_admin(current_user) - data = BillingService.get_model_provider_payment_link(provider_name=provider, - tenant_id=current_user.current_tenant_id, - account_id=current_user.id, - prefilled_email=current_user.email) + data = BillingService.get_model_provider_payment_link( + provider_name=provider, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id, + prefilled_email=current_user.email, + ) return data @@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): @account_initialization_required def post(self, provider: str): model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) return result @@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=False, nullable=True, location='args') + parser.add_argument("token", type=str, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, - provider=provider, - token=args['token'] + 
tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] ) return result -api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') +api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") -api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers//credentials') -api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//credentials/validate') -api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/') -api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers//' - '/') +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource( + ModelProviderIconApi, "/workspaces/current/model-providers//" "/" +) -api.add_resource(PreferredProviderTypeUpdateApi, - '/workspaces/current/model-providers//preferred-provider-type') -api.add_resource(ModelProviderPaymentCheckoutUrlApi, - '/workspaces/current/model-providers//checkout-url') -api.add_resource(ModelProviderFreeQuotaSubmitApi, - '/workspaces/current/model-providers//free-quota-submit') -api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, - '/workspaces/current/model-providers//free-quota-qualification-verify') +api.add_resource( + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" +) +api.add_resource( + ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" +) +api.add_resource( + ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" +) +api.add_resource( + ModelProviderFreeQuotaQualificationVerifyApi, + "/workspaces/current/model-providers//free-quota-qualification-verify", +) diff --git 
a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 69f2253e97..dc88f6b812 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService class DefaultModelApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, - model_type=args['model_type'] + tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({ - "data": default_model_entity - }) + return jsonable_encoder({"data": default_model_entity}) @setup_required @login_required @@ -44,40 +46,39 @@ class DefaultModelApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') + parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - model_settings = args['model_settings'] + model_settings = args["model_settings"] for model_setting in model_settings: - if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: - raise ValueError('invalid model type') + if "model_type" not in model_setting 
or model_setting["model_type"] not in [mt.value for mt in ModelType]: + raise ValueError("invalid model type") - if 'provider' not in model_setting: + if "provider" not in model_setting: continue - if 'model' not in model_setting: - raise ValueError('invalid model') + if "model" not in model_setting: + raise ValueError("invalid model") try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting['model_type'], - provider=model_setting['provider'], - model=model_setting['model'] + model_type=model_setting["model_type"], + provider=model_setting["provider"], + model=model_setting["model"], ) except Exception: logging.warning(f"{model_setting['model_type']} save error") - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_provider( - tenant_id=tenant_id, - provider=provider - ) + models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) @setup_required @login_required @@ -104,61 +100,66 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') - parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') - parser.add_argument('config_from', type=str, required=False, 
nullable=True, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() model_load_balancing_service = ModelLoadBalancingService() - if ('load_balancing' in args and args['load_balancing'] and - 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): - if 'configs' not in args['load_balancing']: - raise ValueError('invalid load balancing configs') + if ( + "load_balancing" in args + and args["load_balancing"] + and "enabled" in args["load_balancing"] + and args["load_balancing"]["enabled"] + ): + if "configs" not in args["load_balancing"]: + raise ValueError("invalid load balancing configs") # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - configs=args['load_balancing']['configs'] + model=args["model"], + model_type=args["model_type"], + configs=args["load_balancing"]["configs"], ) # enable load balancing model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) else: # disable load balancing model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], 
model_type=args["model_type"] ) - if args.get('config_from', '') != 'predefined-model': + if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() try: model_provider_service.save_model_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: + logging.exception(f"save model credentials error: {ex}") raise ValueError(str(ex)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 @setup_required @login_required @@ -170,24 +171,26 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderModelCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -195,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, 
required=True, nullable=False, location='args') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, - provider=provider, - model_type=args['model_type'], - model=args['model'] + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) return { "credentials": credentials, - "load_balancing": { - "enabled": is_load_balancing_enabled, - "configs": load_balancing_configs - } + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, } class ModelProviderModelEnableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -234,24 +233,26 @@ class ModelProviderModelEnableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + 
nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelDisableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -259,24 +260,26 @@ class ModelProviderModelDisableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelValidateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -284,10 +287,16 @@ class ModelProviderModelValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in 
ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -299,48 +308,42 @@ class ModelProviderModelValidateApi(Resource): model_provider_service.model_credentials_validate( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderModelParameterRuleApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, - provider=provider, - model=args['model'] + tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({ - "data": parameter_rules - }) + return jsonable_encoder({"data": parameter_rules}) class ModelProviderAvailableModelApi(Resource): - 
@setup_required @login_required @account_initialization_required @@ -348,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type( - tenant_id=tenant_id, - model_type=model_type - ) + models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) -api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers//models') -api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers//models/enable', - endpoint='model-provider-model-enable') -api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers//models/disable', - endpoint='model-provider-model-disable') -api.add_resource(ModelProviderModelCredentialApi, - '/workspaces/current/model-providers//models/credentials') -api.add_resource(ModelProviderModelValidateApi, - '/workspaces/current/model-providers//models/credentials/validate') +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") +api.add_resource( + ModelProviderModelEnableApi, + "/workspaces/current/model-providers//models/enable", + endpoint="model-provider-model-enable", +) +api.add_resource( + ModelProviderModelDisableApi, + "/workspaces/current/model-providers//models/disable", + endpoint="model-provider-model-disable", +) +api.add_resource( + ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" +) +api.add_resource( + ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" +) -api.add_resource(ModelProviderModelParameterRuleApi, - '/workspaces/current/model-providers//models/parameter-rules') -api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/') 
-api.add_resource(DefaultModelApi, '/workspaces/current/default-model') +api.add_resource( + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" +) +api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") +api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bafeabb08a..c41a898fdc 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -28,10 +28,18 @@ class ToolProviderListApi(Resource): tenant_id = current_user.current_tenant_id req = reqparse.RequestParser() - req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + req.add_argument( + "type", + type=str, + choices=["builtin", "model", "api", "workflow"], + required=False, + nullable=True, + location="args", + ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( - user_id, - tenant_id, - provider, - )) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) + ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id 
@@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource): tenant_id, provider, ) - + + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource): user_id, tenant_id, provider, - args['credentials'], + args["credentials"], ) - + + class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): provider, ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource): icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, 
required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) - parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) + parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args['provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args.get('privacy_policy', ''), - args.get('custom_disclaimer', ''), - args.get('labels', []), + args["provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args.get("privacy_policy", ""), + args.get("custom_disclaimer", ""), + args.get("labels", []), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='args') + parser.add_argument("url", type=str, required=True, nullable=False, 
location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, - args['url'], + args["url"], ) - + + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) + return jsonable_encoder( + ApiToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args["provider"], + ) + ) + class ToolApiProviderUpdateApi(Resource): @setup_required @@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, 
location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args['provider'], - args['original_provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args['privacy_policy'], - args['custom_disclaimer'], - args.get('labels', []), + args["provider"], + args["original_provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args["privacy_policy"], + args["custom_disclaimer"], + args.get("labels", []), ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.delete_api_tool_provider( user_id, 
tenant_id, - args['provider'], + args["provider"], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.parser_api_schema( - schema=args['schema'], + schema=args["schema"], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - 
parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args['provider_name'] if args['provider_name'] else '', - args['tool_name'], - args['credentials'], - args['parameters'], - args['schema_type'], - args['schema'], + args["provider_name"] if args["provider_name"] else "", + args["tool_name"], + args["credentials"], + args["parameters"], + args["schema_type"], + args["schema"], ) + class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, 
nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( user_id, tenant_id, - args['workflow_app_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_app_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, 
required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + args = reqparser.parse_args() - if not args['workflow_tool_id']: - raise ValueError('incorrect workflow_tool_id') - + if not args["workflow_tool_id"]: + raise ValueError("incorrect workflow_tool_id") + return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_tool_id"], + args["name"], + args["label"], + args["icon"], + 
args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") args = reqparser.parse_args() return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - + + class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') - parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") + parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() - if args.get('workflow_tool_id'): + if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - elif args.get('workflow_app_id'): + elif args.get("workflow_app_id"): tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args['workflow_app_id'], + args["workflow_app_id"], ) else: - raise ValueError('incorrect workflow_tool_id or workflow_app_id') + raise 
ValueError("incorrect workflow_tool_id or workflow_app_id") return jsonable_encoder(tool) - + + class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( - user_id, - tenant_id, - args['workflow_tool_id'], - )) + return jsonable_encoder( + WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args["workflow_tool_id"], + ) + ) + class ToolBuiltinListApi(Resource): @setup_required @@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolApiListApi(Resource): @setup_required @login_required @@ -478,11 +517,17 @@ class ToolApiListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in ApiToolManageService.list_api_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return 
jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolLabelsApi(Resource): @setup_required @login_required @@ -503,36 +554,41 @@ class ToolLabelsApi(Resource): def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) + # tool provider -api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') +api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') -api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') -api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') -api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') -api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') -api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" +) +api.add_resource( + ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin//credentials_schema" +) +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") # api tool 
provider -api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') -api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') -api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') -api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') -api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') -api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') -api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") +api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") +api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") +api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") +api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") +api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") +api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") +api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") # workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') -api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') -api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') -api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') -api.add_resource(ToolWorkflowProviderListToolApi, 
'/workspaces/current/tool-provider/workflow/tools') +api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") +api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") +api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") +api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") +api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") -api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') -api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') +api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") +api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file +api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7a11a45ae8..623f0b8b74 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -26,39 +26,34 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService provider_fields = { - 'provider_name': fields.String, - 'provider_type': fields.String, - 'is_valid': fields.Boolean, - 'token_is_set': fields.Boolean, + "provider_name": fields.String, + "provider_type": fields.String, + "is_valid": fields.Boolean, + "token_is_set": fields.Boolean, } tenant_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'role': fields.String, - 
'in_trial': fields.Boolean, - 'trial_end_reason': fields.String, - 'custom_config': fields.Raw(attribute='custom_config'), + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "role": fields.String, + "in_trial": fields.Boolean, + "trial_end_reason": fields.String, + "custom_config": fields.Raw(attribute="custom_config"), } tenants_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'current': fields.Boolean + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "current": fields.Boolean, } -workspace_fields = { - 'id': fields.String, - 'name': fields.String, - 'status': fields.String, - 'created_at': TimestampField -} +workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} class TenantListApi(Resource): @@ -71,7 +66,7 @@ class TenantListApi(Resource): for tenant in tenants: if tenant.id == current_user.current_tenant_id: tenant.current = True # Set current=True for current tenant - return {'workspaces': marshal(tenants, tenants_fields)}, 200 + return {"workspaces": marshal(tenants, tenants_fields)}, 200 class WorkspaceListApi(Resource): @@ -79,31 +74,37 @@ class WorkspaceListApi(Resource): @admin_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = 
db.session.query(Tenant).order_by(Tenant.created_at.desc())\ - .paginate(page=args['page'], per_page=args['limit']) + tenants = ( + db.session.query(Tenant) + .order_by(Tenant.created_at.desc()) + .paginate(page=args["page"], per_page=args["limit"]) + ) has_more = False - if len(tenants.items) == args['limit']: + if len(tenants.items) == args["limit"]: current_page_first_tenant = tenants[-1] - rest_count = db.session.query(Tenant).filter( - Tenant.created_at < current_page_first_tenant.created_at, - Tenant.id != current_page_first_tenant.id - ).count() + rest_count = ( + db.session.query(Tenant) + .filter( + Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id + ) + .count() + ) if rest_count > 0: has_more = True total = db.session.query(Tenant).count() return { - 'data': marshal(tenants.items, workspace_fields), - 'has_more': has_more, - 'limit': args['limit'], - 'page': args['page'], - 'total': total - }, 200 + "data": marshal(tenants.items, workspace_fields), + "has_more": has_more, + "limit": args["limit"], + "page": args["page"], + "total": total, + }, 200 class TenantApi(Resource): @@ -112,8 +113,8 @@ class TenantApi(Resource): @account_initialization_required @marshal_with(tenant_fields) def get(self): - if request.path == '/info': - logging.warning('Deprecated URL /info was used.') + if request.path == "/info": + logging.warning("Deprecated URL /info was used.") tenant = current_user.current_tenant @@ -125,7 +126,7 @@ class TenantApi(Resource): tenant = tenants[0] # else, raise Unauthorized else: - raise Unauthorized('workspace is archived') + raise Unauthorized("workspace is archived") return WorkspaceService.get_tenant_info(tenant), 200 @@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('tenant_id', type=str, required=True, location='json') + parser.add_argument("tenant_id", type=str, required=True, 
location="json") args = parser.parse_args() # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args['tenant_id']) + TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + + return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): parser = reqparse.RequestParser() - parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument("remove_webapp_brand", type=bool, location="json") + parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { - 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , + "remove_webapp_brand": args["remove_webapp_brand"], + "replace_webapp_logo": args["replace_webapp_logo"] + if args["replace_webapp_logo"] is not None + else tenant.custom_config_dict.get("replace_webapp_logo"), } tenant.custom_config_dict = custom_config_dict db.session.commit() - return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), 
tenant_fields)} - + return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - extension = file.filename.split('.')[-1] - if extension.lower() not in ['svg', 'png']: + extension = file.filename.split(".")[-1] + if extension.lower() not in ["svg", "png"]: raise UnsupportedFileTypeError() try: @@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - return { 'id': upload_file.id }, 201 + + return {"id": upload_file.id}, 201 -api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants -api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants -api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info -api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated -api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') -api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') +api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants +api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants +api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current 
tenant info +api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated +api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant +api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") +api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3baf69acfd..5a964c84fa 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -16,7 +16,7 @@ def account_initialization_required(view): # check account initialization account = current_user - if account.status == 'uninitialized': + if account.status == "uninitialized": raise AccountNotInitializedError() return view(*args, **kwargs) @@ -27,7 +27,7 @@ def account_initialization_required(view): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'CLOUD': + if dify_config.EDITION != "CLOUD": abort(404) return view(*args, **kwargs) @@ -38,7 +38,7 @@ def only_edition_cloud(view): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": abort(404) return view(*args, **kwargs) @@ -46,8 +46,9 @@ def only_edition_self_hosted(view): return decorated -def cloud_edition_billing_resource_check(resource: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check( + resource: str, error_msg: str = "You have reached the limit of your subscription." 
+): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -58,22 +59,22 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: + if resource == "members" and 0 < members.limit <= members.size: abort(403, error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: + elif resource == "apps" and 0 < apps.limit <= apps.size: abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: abort(403, error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets - source = request.args.get('source') - if source == 'datasets': + source = request.args.get("source") + if source == "datasets": abort(403, error_msg) else: return view(*args, **kwargs) - elif resource == 'workspace_custom' and not features.can_replace_logo: + elif resource == "workspace_custom" and not features.can_replace_logo: abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: + elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: abort(403, error_msg) else: return view(*args, **kwargs) @@ -85,15 +86,17 @@ def cloud_edition_billing_resource_check(resource: str, return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def 
cloud_edition_billing_knowledge_limit_check( + resource: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", +): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": abort(403, error_msg) else: return view(*args, **kwargs) @@ -112,7 +115,7 @@ def cloud_utm_record(view): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - utm_info = request.cookies.get('utm_info') + utm_info = request.cookies.get("utm_info") if utm_info: utm_info = json.loads(utm_info) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 8d38ab9866..97d5c3f88f 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('files', __name__) +bp = Blueprint("files", __name__) api = ExternalApi(bp) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 247b5d45e1..2432285d93 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -13,35 +13,30 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get('timestamp') - nonce = request.args.get('nonce') - sign = request.args.get('sign') + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") if not timestamp or not nonce or not sign: - return {'content': 'Invalid request.'}, 400 + return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview( - file_id, - timestamp, 
- nonce, - sign - ) + generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - + class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) custom_config = TenantService.get_custom_config(workspace_id) - webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None + webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound('webapp logo is not found') + raise NotFound("webapp logo is not found") try: generator, mimetype = FileService.get_public_image_preview( @@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource): return Response(generator, mimetype=mimetype) -api.add_resource(ImagePreviewApi, '/files//image-preview') -api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces//webapp-logo') +api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." 
code = 415 diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 5a07ad2ea5..38ac0815da 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -13,36 +13,39 @@ class ToolFilePreviewApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('timestamp', type=str, required=True, location='args') - parser.add_argument('nonce', type=str, required=True, location='args') - parser.add_argument('sign', type=str, required=True, location='args') + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") args = parser.parse_args() - if not ToolFileManager.verify_file(file_id=file_id, - timestamp=args['timestamp'], - nonce=args['nonce'], - sign=args['sign'], + if not ToolFileManager.verify_file( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], ): - raise Forbidden('Invalid request.') - + raise Forbidden("Invalid request.") + try: result = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) if not result: - raise NotFound('file is not found') - + raise NotFound("file is not found") + generator, mimetype = result except Exception: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) -api.add_resource(ToolFilePreviewApi, '/files/tools/.') + +api.add_resource(ToolFilePreviewApi, "/files/tools/.") + class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." 
code = 415 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index ad49a649ca..9f124736a9 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -2,8 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) from .workspace import workspace - diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 06610d8933..914b60f263 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -9,29 +9,24 @@ from services.account_service import TenantService class EnterpriseWorkspace(Resource): - @setup_required @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('owner_email', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = Account.query.filter_by(email=args['owner_email']).first() + account = Account.query.filter_by(email=args["owner_email"]).first() if account is None: - return { - 'message': 'owner account not found.' - }, 404 + return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args['name']) - TenantService.create_tenant_member(tenant, account, role='owner') + tenant = TenantService.create_tenant(args["name"]) + TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) - return { - 'message': 'enterprise workspace created.' 
- } + return {"message": "enterprise workspace created."} -api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') +api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 5c37f5276f..51ffe683ff 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -17,7 +17,7 @@ def inner_api_only(view): abort(404) # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) @@ -33,29 +33,29 @@ def inner_api_user_auth(view): return view(*args, **kwargs) # get header 'X-Inner-Api-Key' - authorization = request.headers.get('Authorization') + authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(':') + parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) user_id, token = parts - if ' ' in user_id: - user_id = user_id.split(' ')[1] + if " " in user_id: + user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") - data_to_sign = f'DIFY {user_id}' + data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) - signature = b64encode(signature.digest()).decode('utf-8') + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + signature = b64encode(signature.digest()).decode("utf-8") if signature != token: return view(*args, **kwargs) - kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py 
b/api/controllers/service_api/__init__.py index 082660a891..ad39c160ac 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('service_api', __name__, url_prefix='/v1') +bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 3b3cf1b026..ecc2d73deb 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,3 @@ - from flask_restful import Resource, fields, marshal_with from configs import dify_config @@ -13,32 +12,30 @@ class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + 
"speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @validate_app_token @@ -56,30 +53,35 @@ class AppParameterApi(Resource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": 
features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -89,16 +91,14 @@ class AppMetaApi(Resource): """Get app meta""" return AppService().get_app_meta(app_model) + class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): - """Get app infomation""" - return { - 'name':app_model.name, - 'description':app_model.description - } + """Get app information""" + return {"name": app_model.name, "description": app_model.description} -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMetaApi, '/meta') -api.add_resource(AppInfoApi, '/info') +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMetaApi, "/meta") +api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 3c009af343..85aab047a7 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,14 +33,10 @@ from services.errors.audio import ( class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - file = request.files['file'] + file = request.files["file"] try: - 
response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -74,30 +70,32 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response 
= AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(Resource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2511f46baf..f1771baf31 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService class CompletionApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", 
location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" - args['auto_generate_name'] = False + args["auto_generate_name"] = False try: response = AppGenerateService.generate( @@ -84,12 +84,12 @@ class CompletionApi(Resource): class CompletionStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(Resource): @@ -100,25 +100,21 @@ class ChatApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') - parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + 
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -153,10 +149,10 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4598a14611..734027a1c5 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ from services.conversation_service import ConversationService class ConversationApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): @@ -23,20 +22,26 @@ class ConversationApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), 
required=False, default=20, location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() try: return ConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.SERVICE_API, - sort_by=args['sort_by'] + sort_by=args["sort_by"], ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -56,11 +61,10 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ConversationRenameApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): @@ -71,22 +75,16 @@ class ConversationRenameApi(Resource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = 
parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='conversation_name') -api.add_resource(ConversationApi, '/conversations') -api.add_resource(ConversationDetailApi, '/conversations/', endpoint='conversation_detail') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") +api.add_resource(ConversationApi, "/conversations") +api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ac9edb1b4f..ca91da80c1 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your app mode matches the right API route." 
code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. 
{message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." 
code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1b..e0a772eb31 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -16,15 +16,13 @@ from services.file_service import FileService class FileApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if not file.mimetype: @@ -43,4 +41,4 @@ class FileApi(Resource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 875870e667..b39aaf7dd8 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -17,61 +17,59 @@ from services.message_service import MessageService class MessageListApi(Resource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + 
"document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'message_files': fields.List(fields.String, attribute='files') + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "message_files": fields.List(fields.String, attribute="files"), } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", 
allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @@ -82,14 +80,15 @@ class MessageListApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") 
args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageSuggestedApi(Resource): @@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'result': 'success', 'data': questions} + return {"result": "success", "data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageSuggestedApi, '/messages//suggested') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageSuggestedApi, "/messages//suggested") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 9446f9d588..5822e0921b 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 
'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, + "id": fields.String, + "workflow_id": fields.String, + "status": fields.String, + "inputs": fields.Raw, + "outputs": fields.Raw, + "error": fields.String, + "total_steps": fields.Integer, + "total_tokens": fields.Integer, + "created_at": fields.DateTime, + "finished_at": fields.DateTime, + "elapsed_time": fields.Float, } + class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) @@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run + + class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): @@ -67,20 +70,16 @@ class WorkflowRunApi(Resource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get('response_mode') == 'streaming' + streaming = args.get("response_mode") == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -111,11 +110,9 @@ class 
WorkflowTaskStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowRunDetailApi, '/workflows/run/') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index e0863859a2..c2c0672a03 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -16,7 +16,7 @@ from services.dataset_service import DatasetService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, - tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) # check embedding setting 
provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -51,50 +45,59 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + item["embedding_available"] = True + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. 
Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False) + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) args = parser.parse_args() try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], + name=args["name"], + indexing_technique=args["indexing_technique"], account=current_user, - permission=args['permission'] + permission=args["permission"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() return marshal(dataset, dataset_detail_fields), 200 + class DatasetApi(DatasetApiResource): """Resource for dataset.""" @@ -106,7 +109,7 @@ class DatasetApi(DatasetApiResource): dataset_id (UUID): The ID of the dataset to be deleted. Returns: - dict: A dictionary with a key 'result' and a value 'success' + dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. 
@@ -118,11 +121,12 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') + +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a6..fb48a6c76c 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -27,47 +27,40 @@ from services.file_service import FileService class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - 
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("text", type=str, required=True, nullable=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", 
"file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('text', type=str, required=False, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("text", type=str, required=False, 
nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if args['text']: - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + if args["text"]: + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError 
as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args.get('indexing_technique'): - raise ValueError('indexing_technique is required.') + raise ValueError("Dataset is not exist.") + if not dataset.indexing_technique and not args.get("indexing_technique"): + raise ValueError("indexing_technique is required.") # save file info - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise 
NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = 
str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if 'file' in request.files: + raise ValueError("Dataset is not exist.") + if "file" in request.files: # save file info - file = request.files['file'] - + file = request.files["file"] if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter( - 
Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") document = DocumentService.get_document(dataset.id, document_id) @@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource): # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) query = query.order_by(desc(Document.created_at)) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) 
+ paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { - 'data': marshal(documents, document_fields), - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": marshal(documents, document_fields), + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response @@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # get documents documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - 
document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data -api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') -api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') -api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') -api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') -api.add_resource(DocumentDeleteApi, '/datasets//documents/') -api.add_resource(DocumentListApi, '/datasets//documents') -api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') +api.add_resource(DocumentAddByTextApi, "/datasets//document/create_by_text") +api.add_resource(DocumentAddByFileApi, "/datasets//document/create_by_file") +api.add_resource(DocumentUpdateByTextApi, "/datasets//documents//update_by_text") +api.add_resource(DocumentUpdateByFileApi, "/datasets//documents//update_by_file") +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentListApi, "/datasets//documents") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e77693b6c9..5ff5e08c72 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." 
code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." 
code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0fa2aa65b2..5e10f3b48c 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -21,52 +21,47 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer class SegmentApi(DatasetApiResource): """Resource for segments.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = 
ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + parser.add_argument("segments", type=list, required=False, nullable=True, location="json") args = parser.parse_args() - if args['segments'] is not None: - for args_item in args['segments']: + if args["segments"] is not None: + for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: return {"error": "Segemtns is required"}, 400 @@ -75,61 +70,53 @@ class SegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = 
DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) parser = reqparse.RequestParser() - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - status_list = args['status'] - keyword = args['keyword'] + status_list = args["status"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) total = query.count() segments = query.order_by(DocumentSegment.position).all() - return { - 'data': marshal(segments, segment_fields), - 'doc_form': 
document.doc_form, - 'total': total - }, 200 + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 class DatasetSegmentApi(DatasetApiResource): @@ -137,48 +124,41 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + 
raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -186,35 +166,34 @@ class DatasetSegmentApi(DatasetApiResource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # validate args parser = reqparse.RequestParser() - parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') + parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segment'], document) - segment = SegmentService.update_segment(args['segment'], segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + 
SegmentService.segment_create_args_validate(args["segment"], document) + segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 -api.add_resource(SegmentApi, '/datasets//documents//segments') -api.add_resource(DatasetSegmentApi, '/datasets//documents//segments/') +api.add_resource(SegmentApi, "/datasets//documents//segments") +api.add_resource( + DatasetSegmentApi, "/datasets//documents//segments/" +) diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index c910063ebd..d24c4597e2 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -13,4 +13,4 @@ class IndexApi(Resource): } -api.add_resource(IndexApi, '/') +api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 819512edf0..a596c6f287 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): """ Enum for whereis_user_arg. 
""" - QUERY = 'query' - JSON = 'json' - FORM = 'form' + + QUERY = "query" + JSON = "json" + FORM = "form" class FetchUserArg(BaseModel): @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - api_token = validate_and_get_api_token('app') + api_token = validate_and_get_api_token("app") app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != 'normal': + if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: @@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get('user') + user_id = request.args.get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get('user') + user_id = request.get_json().get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get('user') + user_id = request.form.get("user") else: # use default-user user_id = None @@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if user_id: user_id = str(user_id) - kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) return view_func(*args, **kwargs) + return decorated_view if view is None: @@ -81,9 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio return decorator(view) -def cloud_edition_billing_resource_check(resource: str, - api_token_type: str, 
- error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check( + resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription." +): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -95,33 +97,37 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == 'members' and 0 < members.limit <= members.size: + if resource == "members" and 0 < members.limit <= members.size: raise Forbidden(error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: + elif resource == "apps" and 0 < apps.limit <= apps.size: raise Forbidden(error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: raise Forbidden(error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: raise Forbidden(error_msg) else: return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - api_token_type: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check( + resource: str, + api_token_type: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", +): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if 
features.billing.subscription.plan == 'sandbox': + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": raise Forbidden(error_msg) else: return view(*args, **kwargs) @@ -132,17 +138,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - api_token = validate_and_get_api_token('dataset') - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == api_token.tenant_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role.in_(['owner'])) \ - .filter(Tenant.status == TenantStatus.NORMAL) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + api_token = validate_and_get_api_token("dataset") + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == api_token.tenant_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role.in_(["owner"])) + .filter(Tenant.status == TenantStatus.NORMAL) + .one_or_none() + ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() @@ -156,6 +165,7 @@ def validate_dataset_token(view=None): else: raise Unauthorized("Tenant does not exist.") return view(api_token.tenant_id, *args, **kwargs) + return decorated if view: @@ -170,20 +180,24 @@ def validate_and_get_api_token(scope=None): """ Validate and get API token. 
""" - auth_header = request.headers.get('Authorization') - if auth_header is None or ' ' not in auth_header: + auth_header = request.headers.get("Authorization") + if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': + if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = db.session.query(ApiToken).filter( - ApiToken.token == auth_token, - ApiToken.type == scope, - ).first() + api_token = ( + db.session.query(ApiToken) + .filter( + ApiToken.token == auth_token, + ApiToken.type == scope, + ) + .first() + ) if not api_token: raise Unauthorized("Access token is invalid") @@ -199,23 +213,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. """ if not user_id: - user_id = 'DEFAULT-USER' + user_id = "DEFAULT-USER" - end_user = db.session.query(EndUser) \ + end_user = ( + db.session.query(EndUser) .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() + ) if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='service_api', - is_anonymous=True if user_id == 'DEFAULT-USER' else False, - session_id=user_id + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, ) db.session.add(end_user) db.session.commit() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index aa19bdc034..630b9468a7 100644 --- a/api/controllers/web/__init__.py +++ 
b/api/controllers/web/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('web', __name__, url_prefix='/api') +bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index f4db82552c..aabca93338 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -10,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(WebApiResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": 
fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": 
features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -86,5 +90,5 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMeta, "/meta") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0e905f905a..d062d2893b 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,14 +31,10 @@ from services.errors.audio import ( class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -70,34 +66,36 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): from flask_restful import reqparse + try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') 
+ parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get( - 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(WebApiResource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 948d5fabb5..bd636a0485 100644 --- 
a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -28,30 +28,25 @@ from services.app_generate_service import AppGenerateService # define completion api for user class CompletionApi(WebApiResource): - def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -79,12 +74,12 @@ class CompletionApi(WebApiResource): class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise 
NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(WebApiResource): @@ -94,25 +89,21 @@ class ChatApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -146,10 +137,10 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return 
{"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 334ee382a2..6bbfa94c27 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(WebApiResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -23,26 +22,32 @@ class ConversationListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() pinned = None - if 
'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, - sort_by=args['sort_by'] + sort_by=args["sort_by"], ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -65,7 +70,6 @@ class ConversationApi(WebApiResource): class ConversationRenameApi(WebApiResource): - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -75,24 +79,17 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(WebApiResource): - def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: @@ -120,8 +117,8 @@ class ConversationUnPinApi(WebApiResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, 
'/conversations//name', endpoint='web_conversation_name') -api.add_resource(ConversationListApi, '/conversations') -api.add_resource(ConversationApi, '/conversations/') -api.add_resource(ConversationPinApi, '/conversations//pin') -api.add_resource(ConversationUnPinApi, '/conversations//unpin') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") +api.add_resource(ConversationListApi, "/conversations") +api.add_resource(ConversationApi, "/conversations/") +api.add_resource(ConversationPinApi, "/conversations//pin") +api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index bc87f51051..2f6bb39cf2 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -2,122 +2,126 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your Workflow app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." 
code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." 
code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class WebSSOAuthRequiredError(BaseHTTPException): - error_code = 'web_sso_auth_required' + error_code = "web_sso_auth_required" description = "Web SSO authentication required." 
code = 401 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 69b38faaf6..0563ed2238 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -9,4 +9,4 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index ca83f6037a..253b1d511c 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -10,14 +10,13 @@ from services.file_service import FileService class FileApi(WebApiResource): - @marshal_with(file_fields) def post(self, app_model, end_user): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -32,4 +31,4 @@ class FileApi(WebApiResource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 865d2270ad..56aaaa930a 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -33,48 +33,46 @@ from services.message_service import MessageService class MessageListApi(WebApiResource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': 
fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": 
fields.List(fields.Nested(message_fields)), } @marshal_with(message_infinite_scroll_pagination_fields) @@ -84,14 +82,15 @@ class MessageListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': + if 
app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource): user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) @@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) except MessageNotExistsError: raise NotFound("Message not found") @@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') -api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") +api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 
ccc8683a79..a01ffd8612 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -9,37 +9,37 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" + def get(self): - system_features = FeatureService.get_system_features() - if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() - - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized('X-App-Code header is missing.') + raise Unauthorized("X-App-Code header is missing.") + + if system_features.sso_enforced_for_web: + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter( - Site.code == app_code, - Site.status == 'normal' - ).first() + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model or app_model.status != 'normal' or not app_model.enable_site: + if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='browser', + type="browser", is_anonymous=True, session_id=generate_session_id(), ) @@ -49,20 +49,20 @@ class PassportResource(Resource): payload = { "iss": site.app_id, - 'sub': 'Web API Passport', - 'app_id': site.app_id, - 'app_code': app_code, - 
'end_user_id': end_user.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": app_code, + "end_user_id": end_user.id, } tk = PassportService().issue(payload) return { - 'access_token': tk, + "access_token": tk, } -api.add_resource(PassportResource, '/passport') +api.add_resource(PassportResource, "/passport") def generate_session_id(): @@ -71,7 +71,6 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() + existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() if existing_count == 0: return session_id diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e17869ffdb..8253f5fc57 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -10,67 +10,65 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': 
fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, end_user, args['message_id']) + SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, end_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, 
'/saved-messages') -api.add_resource(SavedMessageApi, '/saved-messages/') +api.add_resource(SavedMessageListApi, "/saved-messages") +api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0f4a7cabe5..2b4d0e7630 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from werkzeug.exceptions import Forbidden @@ -16,41 +15,41 @@ class AppSiteApi(WebApiResource): """Resource for app sites.""" model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "pre_prompt": fields.String, } site_fields = { - 'title': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'description': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'default_language': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "title": fields.String, + "chat_color_theme": fields.String, + 
"chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } app_fields = { - 'app_id': fields.String, - 'end_user_id': fields.String, - 'enable_site': fields.Boolean, - 'site': fields.Nested(site_fields), - 'model_config': fields.Nested(model_config_fields, allow_null=True), - 'plan': fields.String, - 'can_replace_logo': fields.Boolean, - 'custom_config': fields.Raw(attribute='custom_config'), + "app_id": fields.String, + "end_user_id": fields.String, + "enable_site": fields.Boolean, + "site": fields.Nested(site_fields), + "model_config": fields.Nested(model_config_fields, allow_null=True), + "plan": fields.String, + "can_replace_logo": fields.Boolean, + "custom_config": fields.Raw(attribute="custom_config"), } @marshal_with(app_fields) @@ -70,7 +69,7 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, '/site') +api.add_resource(AppSiteApi, "/site") class AppSiteInfo: @@ -88,9 +87,13 @@ class AppSiteInfo: if can_replace_logo: base_url = dify_config.FILES_URL - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) self.custom_config = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': 
replace_webapp_logo, + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 77c468e417..55b0c3e2ab 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -33,17 +33,13 @@ class WorkflowRunApi(WebApiResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=True + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) return helper.compact_generate_response(response) @@ -73,10 +69,8 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index f5ab49d7e1..93dc691d62 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -8,6 +8,7 @@ from controllers.web.error import WebSSOAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service 
import FeatureService @@ -18,7 +19,9 @@ def validate_jwt_token(view=None): app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) + return decorated + if view: return decorator(view) return decorator @@ -26,56 +29,62 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - + app_code = request.headers.get("X-App-Code") try: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + app_code = decoded.get("app_code") + app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() if not app_code or not site: - raise BadRequest('Site URL is no longer valid.') + raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: - raise BadRequest('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + raise BadRequest("Site is disabled.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features) + _validate_web_sso_token(decoded, system_features, app_code) return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features): +def _validate_web_sso_token(decoded, system_features, app_code): + app_web_sso_enabled = False + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if not source or source != 'sso': - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + source = decoded.get("token_source") + if not source or source != "sso": + raise WebSSOAuthRequiredError() # Check if SSO is not enforced for web, and if 
the token source is SSO, raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if source and source == 'sso': - raise Unauthorized('sso token expired.') + if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + source = decoded.get("token_source") + if source and source == "sso": + raise Unauthorized("sso token expired.") class WebApiResource(Resource): diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 19ce401999..81be1a06a7 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,6 +1,6 @@ import base64 +import io import json -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -18,6 +18,7 @@ from anthropic.types import ( ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout +from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -462,7 +463,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 335fa493cd..3f7266f600 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ 
b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,6 +17,7 @@ from botocore.exceptions import ( ServiceNotInRegionError, UnknownServiceError, ) +from PIL.Image import Image # local import from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -381,9 +382,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): try: url = message_content.data image_content = requests.get(url).content - if '?' in url: - url = url.split('?')[0] - mime_type, _ = mimetypes.guess_type(url) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index ebcd0af35b..84241fb6c8 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,7 +1,7 @@ import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -12,6 +12,7 @@ import google.generativeai.client as client import requests from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -371,7 +372,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = 
requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 8901549110..1a7368a2cf 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -1,4 +1,5 @@ import base64 +import io import json import logging from collections.abc import Generator @@ -18,6 +19,7 @@ from anthropic.types import ( ) from google.cloud import aiplatform from google.oauth2 import service_account +from PIL import Image from vertexai.generative_models import HarmBlockThreshold, HarmCategory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -332,7 +334,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 471cb3c94e..5100494e58 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,6 +1,25 @@ import re -from collections.abc import Callable, 
Generator -from typing import cast +from collections.abc import Generator +from typing import Optional, cast + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionChunk, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService -class MaaSClient(MaasService): - def __init__(self, host: str, region: str): +class ArkClientV3: + endpoint_id: Optional[str] = None + ark: Optional[Ark] = None + + def __init__(self, *args, **kwargs): + self.ark = Ark(*args, **kwargs) self.endpoint_id = None - super().__init__(host, region) - - def set_endpoint_id(self, endpoint_id: str): - self.endpoint_id = endpoint_id - - @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] - - client = 
cls(host, region) - client.set_endpoint_id(endpoint_id) - client.set_ak(ak) - client.set_sk(sk) - return client - - def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: - req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], - **extra_model_kwargs, - } - if not stream: - return super().chat( - self.endpoint_id, - req, - ) - return super().stream_chat( - self.endpoint_id, - req, - ) - - def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } - return super().embeddings(self.endpoint_id, req) @staticmethod - def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + def is_legacy(credentials: dict) -> bool: + if ArkClientV3.is_compatible_with_legacy(credentials): + return False + sdk_version = credentials.get("sdk_version", "v2") + return sdk_version != "v3" + + @staticmethod + def is_compatible_with_legacy(credentials: dict) -> bool: + sdk_version = credentials.get("sdk_version") + endpoint = credentials.get("api_endpoint_host") + return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com" + + @classmethod + def from_credentials(cls, credentials): + """Initialize the client using the credentials provided.""" + args = { + "base_url": credentials['api_endpoint_host'], + "region": credentials['volc_region'], + "ak": credentials['volc_access_key_id'], + "sk": credentials['volc_secret_access_key'], + } + if cls.is_compatible_with_legacy(credentials): + args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3" + + client = ArkClientV3( + **args + ) + client.endpoint_id = credentials['endpoint_id'] + return client + + @staticmethod + def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam: + """Converts a PromptMessage to a ChatCompletionMessageParam""" if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if 
isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + content = message.content else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + content.append(ChatCompletionContentPartTextParam( + text=message_content.text, + type='text', + )) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast( ImagePromptMessageContent, message_content) image_data = re.sub( r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, - } - }) - - message_dict = {'role': ChatRole.USER, 'content': content} + content.append(ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type='image_url', + )) + message_dict = ChatCompletionUserMessageParam( + role='user', + content=content + ) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} - if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls + message_dict = ChatCompletionAssistantMessageParam( + content=message.content, + role='assistant', + tool_calls=None if not message.tool_calls else [ + ChatCompletionMessageToolCallParam( + id=call.id, + function=Function( + name=call.function.name, + arguments=call.function.arguments + ), + type='function' + ) for call in message.tool_calls ] + ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = ChatCompletionSystemMessageParam( + 
content=message.content, + role='system' + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = ChatCompletionToolMessageParam( + content=message.content, + role='tool', + tool_call_id=message.tool_call_id + ) else: raise ValueError(f"Got unknown PromptMessage type {message}") return message_dict @staticmethod - def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: - try: - resp = fn() - except MaasException as e: - raise wrap_error(e) + def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type='function', + function=FunctionDefinition( + name=message.name, + description=message.description, + parameters=message.parameters, + ) + ) - return resp + def chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: + """Block chat""" + return self.ark.chat.completions.create( + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) - @staticmethod - def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - } + def stream_chat(self, messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + 
stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: + """Stream chat""" + chunks = self.ark.chat.completions.create( + stream=True, + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) + for chunk in chunks: + if not chunk.choices: + continue + yield chunk + + def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse: + return self.ark.embeddings.create(model=self.endpoint_id, input=texts) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py new file mode 100644 index 0000000000..1978c11680 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -0,0 +1,134 @@ +import re +from collections.abc import Callable, Generator +from typing import cast + +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService + + +class 
MaaSClient(MaasService): + def __init__(self, host: str, region: str): + self.endpoint_id = None + super().__init__(host, region) + + def set_endpoint_id(self, endpoint_id: str): + self.endpoint_id = endpoint_id + + @classmethod + def from_credential(cls, credentials: dict) -> 'MaaSClient': + host = credentials['api_endpoint_host'] + region = credentials['volc_region'] + ak = credentials['volc_access_key_id'] + sk = credentials['volc_secret_access_key'] + endpoint_id = credentials['endpoint_id'] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: + req = { + 'parameters': params, + 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + **extra_model_kwargs, + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = { + 'input': texts + } + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, + "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError( + 'Content object type only support image_url') + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast( + ImagePromptMessageContent, message_content) + image_data = re.sub( + r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) + content.append({ + 'type': 'image_url', + 'image_url': { + 'url': '', + 'image_bytes': image_data, + 'detail': 
message_content.detail, + } + }) + + message_dict = {'role': ChatRole.USER, 'content': content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {'role': ChatRole.ASSISTANT, + 'content': message.content} + if message.tool_calls: + message_dict['tool_calls'] = [ + { + 'name': call.function.name, + 'arguments': call.function.arguments + } for call in message.tool_calls + ] + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {'role': ChatRole.SYSTEM, + 'content': message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {'role': ChatRole.FUNCTION, + 'content': message.content, + 'name': message.tool_call_id} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasException as e: + raise wrap_error(e) + + return resp + + @staticmethod + def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py similarity index 97% rename from api/core/model_runtime/model_providers/volcengine_maas/errors.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 63397a456e..21ffaf1258 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,4 +1,4 @@ -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from 
core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException class ClientSDKRequestError(MaasException): diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py rename to 
api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index add5822bef..996c66e604 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -1,8 +1,10 @@ import logging from collections.abc import Generator +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient 
-from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) from core.model_runtime.model_providers.volcengine_maas.llm.models import ( get_model_config, get_v2_req_params, + get_v3_req_params, ) -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException logger = logging.getLogger(__name__) @@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials """ - # ping + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(credentials) + return self._validate_credentials_v3(credentials) + + @staticmethod + def _validate_credentials_v2(credentials: dict) -> None: client = MaaSClient.from_credential(credentials) try: client.chat( @@ -70,18 +81,24 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): except MaasException as e: raise CredentialsValidateFailedError(e.message) + @staticmethod + def _validate_credentials_v3(credentials: dict) -> None: + client = ArkClientV3.from_credentials(credentials) + 
try: + client.chat(max_tokens=16, temperature=0.7, top_p=0.9, + messages=[UserPromptMessage(content='ping\nAnswer: ')], ) + except Exception as e: + raise CredentialsValidateFailedError(e) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None) -> int: - if len(prompt_messages) == 0: + if ArkClientV3.is_legacy(credentials): + return self._get_num_tokens_v2(prompt_messages) + return self._get_num_tokens_v3(prompt_messages) + + def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: return 0 - return self._num_tokens_from_messages(prompt_messages) - - def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: - """ - Calculate num tokens. - - :param messages: messages - """ num_tokens = 0 messages_dict = [ MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] @@ -92,9 +109,22 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): return num_tokens - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: + return 0 + num_tokens = 0 + messages_dict = [ + ArkClientV3.convert_prompt_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + return num_tokens + + def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) @@ 
-106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): ] resp = MaaSClient.wrap_exception( lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) - if not stream: - return self._handle_chat_response(model, credentials, prompt_messages, resp) - return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) - def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: - for index, r in enumerate(resp): - choices = r['choices'] + def _handle_stream_chat_response() -> Generator: + for index, r in enumerate(resp): + choices = r['choices'] + if not choices: + continue + choice = choices[0] + message = choice['message'] + usage = None + if r.get('usage'): + usage = self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=r['usage']['prompt_tokens'], + completion_tokens=r['usage']['completion_tokens'] + ) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=[] + ), + usage=usage, + finish_reason=choice.get('finish_reason'), + ), + ) + + def _handle_chat_response() -> LLMResult: + choices = resp['choices'] if not choices: - continue + raise ValueError("No choices found") + choice = choices[0] message = choice['message'] - usage = None - if r.get('usage'): - usage = self._calc_usage(model, credentials, r['usage']) - yield LLMResultChunk( + + # parse tool calls + tool_calls = [] + if message['tool_calls']: + for call in message['tool_calls']: + tool_call = AssistantPromptMessage.ToolCall( + id=call['function']['name'], + type=call['type'], + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call['function']['name'], + arguments=call['function']['arguments'] + ) + ) + tool_calls.append(tool_call) + + usage = resp['usage'] 
+ return LLMResult( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] - ), - usage=usage, - finish_reason=choice.get('finish_reason'), + message=AssistantPromptMessage( + content=message['content'] if message['content'] else '', + tool_calls=tool_calls, ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage['prompt_tokens'], + completion_tokens=usage['completion_tokens'] + ), ) - def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: - choices = resp['choices'] - if not choices: - return - choice = choices[0] - message = choice['message'] + if not stream: + return _handle_chat_response() + return _handle_stream_chat_response() - # parse tool calls - tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: - tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ + -> LLMResult | Generator: + + client = ArkClientV3.from_credentials(credentials) + req_params = get_v3_req_params(credentials, model_parameters, stop) + if tools: + req_params['tools'] = tools + + def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: + for chunk in chunks: + if not chunk.choices: + continue + choice = chunk.choices[0] + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=choice.index, + 
message=AssistantPromptMessage( + content=choice.delta.content, + tool_calls=[] + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens + ) if chunk.usage else None, + finish_reason=choice.finish_reason, + ), ) - tool_calls.append(tool_call) - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=tool_calls, - ), - usage=self._calc_usage(model, credentials, resp['usage']), - ) + def _handle_chat_response(resp: ChatCompletion) -> LLMResult: + choice = resp.choices[0] + message = choice.message + # parse tool calls + tool_calls = [] + if message.tool_calls: + for call in message.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=call.id, + type=call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call.function.name, + arguments=call.function.arguments + ) + ) + tool_calls.append(tool_call) - def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: - return self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ) + usage = resp.usage + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message.content if message.content else "", + tool_calls=tool_calls, + ), + usage=self._calc_response_usage(model=model, credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens + ), + ) + + if not stream: + resp = client.chat(prompt_messages, **req_params) + return _handle_chat_response(resp) + + chunks = client.stream_chat(prompt_messages, **req_params) + return _handle_stream_chat_response(chunks) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | 
None: """ used to define customizable model schema """ model_config = get_model_config(credentials) - + rules = [ ParameterRule( name='temperature', @@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): use_template='presence_penalty', label=I18nObject( en_US='Presence Penalty', - zh_Hans= '存在惩罚', + zh_Hans='存在惩罚', ), min=-2.0, max=2.0, @@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): type=ParameterType.FLOAT, use_template='frequency_penalty', label=I18nObject( - en_US= 'Frequency Penalty', - zh_Hans= '频率惩罚', + en_US='Frequency Penalty', + zh_Hans='频率惩罚', ), min=-2.0, max=2.0, @@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model_properties = {} model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value - + entity = AIModelEntity( model=model, label=I18nObject( diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index c5e53d8955..4e2c66a066 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature class ModelProperties(BaseModel): - context_size: int - max_tokens: int + context_size: int + max_tokens: int mode: LLMMode + class ModelConfig(BaseModel): properties: ModelProperties features: list[ModelFeature] @@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = { features=[ModelFeature.TOOL_CALL] ), 'Doubao-pro-32k': ModelConfig( - properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-lite-32k': ModelConfig( - 
properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-pro-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Doubao-lite-128k': ModelConfig( - properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[ModelFeature.TOOL_CALL] ), 'Skylark2-pro-4k': ModelConfig( - properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT), + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] ), 'Llama3-8B': ModelConfig( @@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = { ) } -def get_model_config(credentials: dict)->ModelConfig: + +def get_model_config(credentials: dict) -> ModelConfig: base_model = credentials.get('base_model_name', '') model_configs = configs.get(base_model) if not model_configs: return ModelConfig( - properties=ModelProperties( + properties=ModelProperties( context_size=int(credentials.get('context_size', 0)), max_tokens=int(credentials.get('max_tokens', 0)), - mode= LLMMode.value_of(credentials.get('mode', 'chat')), + mode=LLMMode.value_of(credentials.get('mode', 'chat')), ), features=[] ) return model_configs -def get_v2_req_params(credentials: dict, model_parameters: dict, - stop: list[str] | None=None): +def get_v2_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): req_params = {} # predefined properties model_configs = get_model_config(credentials) @@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, if model_parameters.get('frequency_penalty'): 
req_params['frequency_penalty'] = model_parameters.get( 'frequency_penalty') - + if stop: req_params['stop'] = stop - return req_params \ No newline at end of file + return req_params + + +def get_v3_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params['max_tokens'] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get('max_tokens'): + req_params['max_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + + if stop: + req_params['stop'] = stop + + return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 2d8f972b94..74cf26247c 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -2,26 +2,29 @@ from pydantic import BaseModel class ModelProperties(BaseModel): - context_size: int - max_chunks: int + context_size: int + max_chunks: int + class ModelConfig(BaseModel): properties: ModelProperties + ModelConfigs = { 'Doubao-embedding': ModelConfig( - properties=ModelProperties(context_size=4096, max_chunks=1) + properties=ModelProperties(context_size=4096, max_chunks=32) ), } -def get_model_config(credentials: dict)->ModelConfig: + +def get_model_config(credentials: dict) -> 
ModelConfig: base_model = credentials.get('base_model_name', '') model_configs = ModelConfigs.get(base_model) if not model_configs: return ModelConfig( - properties=ModelProperties( + properties=ModelProperties( context_size=int(credentials.get('context_size', 0)), max_chunks=int(credentials.get('max_chunks', 0)), ) ) - return model_configs \ No newline at end of file + return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 8ac632369e..d54aeeb0b1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasException, RateLimitErrors, ServerUnavailableErrors, ) from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): @@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): :param user: unique user id :return: embeddings result """ + if ArkClientV3.is_legacy(credentials): + 
return self._generate_v2(model, credentials, texts, user) + + return self._generate_v3(model, credentials, texts, user) + + def _generate_v2(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) @@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): return result + def _generate_v3(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + client = ArkClientV3.from_credentials(credentials) + resp = client.embeddings(texts) + + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=resp.usage.total_tokens) + + result = TextEmbeddingResult( + model=model, + embeddings=[v.embedding for v in resp.data], + usage=usage + ) + + return result + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): :param credentials: model credentials :return: """ + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(model, credentials) + return self._validate_credentials_v3(model, credentials) + + def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: self._invoke(model=model, credentials=credentials, texts=['ping']) except MaasException as e: raise CredentialsValidateFailedError(e.message) + def _validate_credentials_v3(self, model: str, credentials: dict) -> None: + try: + self._invoke(model=model, credentials=credentials, texts=['ping']) + except Exception as e: + raise CredentialsValidateFailedError(e) + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): generate custom model 
entities from credentials """ model_config = get_model_config(credentials) - model_properties = {} - model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size - model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks + model_properties = { + ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks + } entity = AIModelEntity( model=model, label=I18nObject(en_US=model), diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index fc6d231f8e..c970e3dafa 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -616,6 +616,7 @@ class DatasetRetrieval: for document in all_documents: if score_threshold is None or document.metadata['score'] >= score_threshold: filter_documents.append(document) + if not filter_documents: return [] filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) diff --git a/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg new file mode 100644 index 0000000000..ad6b384f7a --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py new file mode 100644 index 0000000000..0df78280df --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -0,0 +1,19 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SiliconflowProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://api.siliconflow.cn/v1/models" + 
headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('siliconFlow_api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + "SiliconFlow API key is invalid" + ) diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml new file mode 100644 index 0000000000..46be99f262 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml @@ -0,0 +1,21 @@ +identity: + author: hjlarry + name: siliconflow + label: + en_US: SiliconFlow + zh_CN: 硅基流动 + description: + en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models. + zh_CN: 硅基流动提供的图片生成 API,包含 Flux 和 Stable Diffusion 模型。 + icon: icon.svg + tags: + - image +credentials_for_provider: + siliconFlow_api_key: + type: secret-input + required: true + label: + en_US: SiliconFlow API Key + placeholder: + en_US: Please input your SiliconFlow API key + url: https://cloud.siliconflow.cn/account/ak diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py new file mode 100644 index 0000000000..ed9f4be574 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -0,0 +1,44 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +FLUX_URL = ( + "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image" +) + + +class FluxTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + 
payload = { + "prompt": tool_parameters.get("prompt"), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "seed": tool_parameters.get("seed"), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(FLUX_URL, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml new file mode 100644 index 0000000000..2a0698700c --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml @@ -0,0 +1,73 @@ +identity: + name: flux + author: hjlarry + label: + en_US: Flux + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's flux schnell. + llm: This tool is used to generate image from prompt via SiliconFlow's flux schnell model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. 
+ form: llm + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 768x1024 + label: + en_US: 768x1024 + - value: 576x1024 + label: + en_US: 576x1024 + - value: 512x1024 + label: + en_US: 512x1024 + - value: 1024x576 + label: + en_US: 1024x576 + - value: 768x512 + label: + en_US: 768x512 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成的图片大小 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + form: form + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py new file mode 100644 index 0000000000..e8134a6565 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -0,0 +1,51 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SDURL = { + "sd_3": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-3-medium/text-to-image", + "sd_xl": "https://api.siliconflow.cn/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image", +} + + +class StableDiffusionTool(BuiltinTool): + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer 
{self.runtime.credentials['siliconFlow_api_key']}", + } + + model = tool_parameters.get("model", "sd_3") + url = SDURL.get(model) + + payload = { + "prompt": tool_parameters.get("prompt"), + "negative_prompt": tool_parameters.get("negative_prompt", ""), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "batch_size": tool_parameters.get("batch_size", 1), + "seed": tool_parameters.get("seed"), + "guidance_scale": tool_parameters.get("guidance_scale", 7.5), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append( + self.create_image_message( + image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml new file mode 100644 index 0000000000..abb541c485 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml @@ -0,0 +1,121 @@ +identity: + name: stable_diffusion + author: hjlarry + label: + en_US: Stable Diffusion + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's stable diffusion model. + llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. 
+ form: llm + - name: negative_prompt + type: string + label: + en_US: negative prompt + zh_Hans: 负面提示词 + human_description: + en_US: Describe what you don't want included in the image. + zh_Hans: 描述您不希望包含在图片中的内容。 + llm_description: Describe what you don't want included in the image. + form: llm + - name: model + type: select + required: true + options: + - value: sd_3 + label: + en_US: Stable Diffusion 3 + - value: sd_xl + label: + en_US: Stable Diffusion XL + default: sd_3 + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 1024x2048 + label: + en_US: 1024x2048 + - value: 1152x2048 + label: + en_US: 1152x2048 + - value: 1536x1024 + label: + en_US: 1536x1024 + - value: 1536x2048 + label: + en_US: 1536x2048 + - value: 2048x1152 + label: + en_US: 2048x1152 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成图片的大小 + form: form + - name: batch_size + type: number + required: true + default: 1 + min: 1 + max: 4 + label: + en_US: Number Images + zh_Hans: 生成图片的数量 + form: form + - name: guidance_scale + type: number + required: true + default: 7 + min: 0 + max: 100 + label: + en_US: Guidance Scale + zh_Hans: 与提示词紧密性 + human_description: + en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you. + zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. 
+ zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + form: form + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. + zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 9cf5c505d1..1edc558676 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -20,11 +20,11 @@ if [[ "${MODE}" == "worker" ]]; then CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel INFO \ + exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \ -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion} elif [[ "${MODE}" == "beat" ]]; then - exec celery -A app.celery beat --loglevel INFO + exec celery -A app.celery beat --loglevel ${LOG_LEVEL} else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/poetry.lock b/api/poetry.lock index 0a8919a30a..a273b3c4d9 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -6143,6 +6143,19 @@ files = [ {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] @@ -8854,6 +8867,28 @@ files = [ {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, ] +[[package]] +name = "volcengine-python-sdk" +version = "1.0.98" +description = "Volcengine SDK for Python" +optional = false +python-versions = 
"*" +files = [ + {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"}, +] + +[package.dependencies] +anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""} +certifi = ">=2017.4.17" +httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""} +pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""} +python-dateutil = ">=2.1" +six = ">=1.10" +urllib3 = ">=1.23" + +[package.extras] +ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"] + [[package]] name = "watchfiles" version = "0.23.0" @@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02" +content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243" diff --git a/api/pyproject.toml b/api/pyproject.toml index e05a51dc13..e2d6704e8a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -71,11 +71,8 @@ ignore = [ [tool.ruff.format] exclude = [ "core/**/*.py", - "controllers/**/*.py", "models/**/*.py", "migrations/**/*", - "services/**/*.py", - "tasks/**/*.py", ] [tool.pytest_env] @@ -191,6 +188,7 @@ zhipuai = "1.0.7" # Related transparent dependencies with pinned verion # required by main implementations ############################################################ +volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"} [tool.poetry.group.indriect.dependencies] kaleido = "0.2.1" rank-bm25 = "~0.2.2" diff --git a/api/services/__init__.py b/api/services/__init__.py index 6891436314..5163862cc1 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,3 +1,3 @@ from . 
import errors -__all__ = ['errors'] +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py index d73cec2697..cd501c9792 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task class AccountService: - - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", - max_attempts=5, - time_window=60 * 60 - ) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) @staticmethod def load_user(user_id: str) -> None | Account: @@ -55,12 +50,15 @@ class AccountService: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( + account_id=account.id, current=True + ).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if not available_ta: return None @@ -74,14 +72,13 @@ class AccountService: return account - @staticmethod def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): payload = { "user_id": account.id, "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "iss": dify_config.EDITION, - "sub": 'Console API Passport', + "sub": "Console API Passport", } token = PassportService().issue(payload) @@ -93,10 +90,10 @@ class AccountService: account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError('Invalid email or password.') + 
raise AccountLoginError("Invalid email or password.") if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise AccountLoginError('Account is banned or closed.') + raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -104,7 +101,7 @@ class AccountService: db.session.commit() if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") return account @staticmethod @@ -129,11 +126,9 @@ class AccountService: return account @staticmethod - def create_account(email: str, - name: str, - interface_language: str, - password: Optional[str] = None, - interface_theme: str = 'light') -> Account: + def create_account( + email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" + ) -> Account: """create account""" account = Account() account.email = email @@ -155,7 +150,7 @@ class AccountService: account.interface_theme = interface_theme # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, "UTC") db.session.add(account) db.session.commit() @@ -166,8 +161,9 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, - provider=provider).first() + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -176,15 +172,16 @@ class AccountService: account_integrate.updated_at = 
datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record - account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, - encrypted_token="") + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) db.session.add(account_integrate) db.session.commit() - logging.info(f'Account {account.id} linked {provider} account {open_id}.') + logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') - raise LinkAccountIntegrateError('Failed to link account.') from e + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod def close_account(account: Account) -> None: @@ -218,7 +215,7 @@ class AccountService: AccountService.update_last_login(account, ip_address=ip_address) exp = timedelta(days=30) token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) + redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) return token @staticmethod @@ -236,22 +233,18 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account.email): raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. 
Please try again later.") - token = TokenManager.generate_token(account, 'reset_password') - send_reset_password_mail_task.delay( - language=account.interface_language, - to=account.email, - token=token - ) + token = TokenManager.generate_token(account, "reset_password") + send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) cls.reset_password_rate_limiter.increment_rate_limit(account.email) return token @classmethod def revoke_reset_password_token(cls, token: str): - TokenManager.revoke_token(token, 'reset_password') + TokenManager.revoke_token(token, "reset_password") @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, 'reset_password') + return TokenManager.get_token_data(token, "reset_password") def _get_login_cache_key(*, account_id: str, token: str): @@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: - @staticmethod def create_tenant(name: str) -> Tenant: """Create tenant""" @@ -275,31 +267,28 @@ class TenantService: @staticmethod def create_owner_tenant_if_not_exist(account: Account): """Create owner tenant if not exist""" - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if available_ta: return tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> 
TenantAccountJoin: """Create tenant member""" if role == TenantAccountJoinRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): - logging.error(f'Tenant {tenant.id} has already an owner.') - raise Exception('Tenant already has an owner.') + logging.error(f"Tenant {tenant.id} has already an owner.") + raise Exception("Tenant already has an owner.") - ta = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=role - ) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) db.session.commit() return ta @@ -307,9 +296,12 @@ class TenantService: @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return db.session.query(Tenant).join( - TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() + return ( + db.session.query(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .all() + ) @staticmethod def get_current_tenant_by_account(account: Account): @@ -333,16 +325,23 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == tenant_id, - Tenant.status == TenantStatus.NORMAL, - ).first() + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) + .filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ) + .first() + ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the 
tenant.") else: - TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) + TenantAccountJoin.query.filter( + TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id + ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id @@ -354,9 +353,7 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) @@ -375,11 +372,9 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == 'dataset_operator') + .filter(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -395,20 +390,25 @@ class TenantService: def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountJoinRole) for role in roles): - raise ValueError('all roles must be TenantAccountJoinRole') + raise ValueError("all roles must be TenantAccountJoinRole") - return db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role.in_([role.value for role in roles]) - ).first() is not None + return ( + db.session.query(TenantAccountJoin) + .filter( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) + ) + .first() + 
is not None + ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: """Get the role of the current account for a given tenant""" - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == account.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .first() + ) return join.role if join else None @staticmethod @@ -420,29 +420,26 @@ class TenantService: def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: """Check member permission""" perms = { - 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], - 'remove': [TenantAccountRole.OWNER], - 'update': [TenantAccountRole.OWNER] + "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], + "remove": [TenantAccountRole.OWNER], + "update": [TenantAccountRole.OWNER], } - if action not in ['add', 'remove', 'update']: + if action not in ["add", "remove", "update"]: raise InvalidActionError("Invalid action.") if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: - raise NoPermissionError(f'No permission to {action} member.') + raise NoPermissionError(f"No permission to {action} member.") @staticmethod def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" - if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 
"remove"): raise CannotOperateSelfError("Cannot operate self.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -455,23 +452,17 @@ class TenantService: @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" - TenantService.check_member_permission(tenant, operator, member, 'update') + TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == 'owner': + if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - role='owner' - ).first() - current_owner_join.role = 'admin' + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -480,8 +471,8 @@ class TenantService: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): - raise NoPermissionError('No permission to dissolve tenant.') + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): + raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) db.session.commit() @@ -494,10 +485,9 @@ class TenantService: class RegisterService: - @classmethod def _get_invitation_token_key(cls, token: str) -> str: 
- return f'member_invite:token:{token}' + return f"member_invite:token:{token}" @classmethod def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: @@ -523,9 +513,7 @@ class RegisterService: TenantService.create_owner_tenant_if_not_exist(account) - dify_setup = DifySetup( - version=dify_config.CURRENT_VERSION - ) + dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) db.session.commit() except Exception as e: @@ -535,34 +523,35 @@ class RegisterService: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f'Setup failed: {e}') - raise ValueError(f'Setup failed: {e}') + logging.exception(f"Setup failed: {e}") + raise ValueError(f"Setup failed: {e}") @classmethod - def register(cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None) -> Account: + def register( + cls, + email, + name, + password: Optional[str] = None, + open_id: Optional[str] = None, + provider: Optional[str] = None, + language: Optional[str] = None, + status: Optional[AccountStatus] = None, + ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( - email=email, - name=name, - interface_language=language if language else languages[0], - password=password + email=email, name=name, interface_language=language if language else languages[0], password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + 
TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) @@ -570,30 +559,29 @@ class RegisterService: db.session.commit() except Exception as e: db.session.rollback() - logging.error(f'Register failed: {e}') - raise AccountRegisterError(f'Registration failed: {e}') from e + logging.error(f"Register failed: {e}") + raise AccountRegisterError(f"Registration failed: {e}") from e return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member( + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() if not account: - TenantService.check_member_permission(tenant, inviter, None, 'add') - name = email.split('@')[0] + TenantService.check_member_permission(tenant, inviter, None, "add") + name = email.split("@")[0] account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) else: - TenantService.check_member_permission(tenant, inviter, account, 'add') - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=account.id - ).first() + TenantService.check_member_permission(tenant, inviter, account, "add") + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -609,7 +597,7 @@ class RegisterService: language=account.interface_language, to=email, token=token, - inviter_name=inviter.name if inviter else 'Dify', + inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) @@ -619,23 +607,19 @@ class RegisterService: def 
generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) invitation_data = { - 'account_id': account.id, - 'email': account.email, - 'workspace_id': tenant.id, + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, } expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex( - cls._get_invitation_token_key(token), - expiryHours * 60 * 60, - json.dumps(invitation_data) - ) + redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) return token @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -646,17 +630,21 @@ class RegisterService: if not invitation_data: return None - tenant = db.session.query(Tenant).filter( - Tenant.id == invitation_data['workspace_id'], - Tenant.status == 'normal' - ).first() + tenant = ( + db.session.query(Tenant) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .first() + ) if not tenant: return None - tenant_account = db.session.query(Account, TenantAccountJoin.role).join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() + tenant_account = ( + db.session.query(Account, TenantAccountJoin.role) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) + .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .first() + ) if not tenant_account: return None @@ -665,29 +653,29 @@ class RegisterService: if not account: return None - if 
invitation_data['account_id'] != str(account.id): + if invitation_data["account_id"] != str(account.id): return None return { - 'account': account, - 'data': invitation_data, - 'tenant': tenant, + "account": account, + "data": invitation_data, + "tenant": tenant, } @classmethod def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() - cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) if not account_id: return None return { - 'account_id': account_id.decode('utf-8'), - 'email': email, - 'workspace_id': workspace_id, + "account_id": account_id.decode("utf-8"), + "email": email, + "workspace_id": workspace_id, } else: data = redis_client.get(cls._get_invitation_token_key(token)) diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 213df26222..d2cd7bea67 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,3 @@ - import copy from core.prompt.prompt_templates.advanced_prompt_templates import ( @@ -17,59 +16,78 @@ from models.model import AppMode class AdvancedPromptTemplateService: - @classmethod def get_prompt(cls, args: dict) -> dict: - app_mode = args['app_mode'] - model_mode = args['model_mode'] - model_name = args['model_name'] - has_context = args['has_context'] + app_mode = args["app_mode"] + model_mode = args["model_mode"] + model_name = args["model_name"] + has_context = args["has_context"] - if 'baichuan' in model_name.lower(): + if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, 
app_mode: str, model_mode:str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] - + if has_context == "true": + prompt_template["completion_prompt_config"]["prompt"]["text"] = ( + context + prompt_template["completion_prompt_config"]["prompt"]["text"] + ) + return prompt_template @classmethod def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] - + if has_context == "true": + 
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( + context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] + ) + return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index ba5fd93326..887fb878b9 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough class AgentService: 
@classmethod - def get_agent_logs(cls, app_model: App, - conversation_id: str, - message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: """ Service to get agent logs """ - conversation: Conversation = db.session.query(Conversation).filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - ).first() + conversation: Conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + ) + .first() + ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = db.session.query(Message).filter( - Message.id == message_id, - Message.conversation_id == conversation_id, - ).first() + message: Message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.conversation_id == conversation_id, + ) + .first() + ) if not message: raise ValueError(f"Message not found: {message_id}") - + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if conversation.from_end_user_id: # only select name field - executor = db.session.query(EndUser, EndUser.name).filter( - EndUser.id == conversation.from_end_user_id - ).first() + executor = ( + db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + ) else: - executor = db.session.query(Account, Account.name).filter( - Account.id == conversation.from_account_id - ).first() - + executor = ( + db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() + ) + if executor: executor = executor.name else: - executor = 'Unknown' + executor = "Unknown" timezone = pytz.timezone(current_user.timezone) result = { - 'meta': { - 'status': 'success', - 'executor': executor, - 'start_time': message.created_at.astimezone(timezone).isoformat(), - 'elapsed_time': message.provider_response_latency, - 'total_tokens': message.answer_tokens 
+ message.message_tokens, - 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), - 'iterations': len(agent_thoughts), + "meta": { + "status": "success", + "executor": executor, + "start_time": message.created_at.astimezone(timezone).isoformat(), + "elapsed_time": message.provider_response_latency, + "total_tokens": message.answer_tokens + message.message_tokens, + "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), + "iterations": len(agent_thoughts), }, - 'iterations': [], - 'files': message.files, + "iterations": [], + "files": message.files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) @@ -86,12 +92,12 @@ class AgentService: tool_input = tool_inputs.get(tool_name, {}) tool_output = tool_outputs.get(tool_name, {}) tool_meta_data = tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': + tool_config = tool_meta_data.get("tool_config", {}) + if tool_config.get("tool_provider_type", "") != "dataset-retrieval": tool_icon = ToolManager.get_tool_icon( tenant_id=app_model.tenant_id, - provider_type=tool_config.get('tool_provider_type', ''), - provider_id=tool_config.get('tool_provider', ''), + provider_type=tool_config.get("tool_provider_type", ""), + provider_id=tool_config.get("tool_provider", ""), ) if not tool_icon: tool_entity = find_agent_tool(tool_name) @@ -102,30 +108,34 @@ class AgentService: provider_id=tool_entity.provider_id, ) else: - tool_icon = '' + tool_icon = "" - tool_calls.append({ - 'status': 'success' if not tool_meta_data.get('error') else 'error', - 'error': tool_meta_data.get('error'), - 'time_cost': tool_meta_data.get('time_cost', 0), - 'tool_name': tool_name, - 'tool_label': tool_label, - 'tool_input': tool_input, - 'tool_output': tool_output, - 'tool_parameters': tool_meta_data.get('tool_parameters', {}), - 'tool_icon': tool_icon, - }) + 
tool_calls.append( + { + "status": "success" if not tool_meta_data.get("error") else "error", + "error": tool_meta_data.get("error"), + "time_cost": tool_meta_data.get("time_cost", 0), + "tool_name": tool_name, + "tool_label": tool_label, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_parameters": tool_meta_data.get("tool_parameters", {}), + "tool_icon": tool_icon, + } + ) - result['iterations'].append({ - 'tokens': agent_thought.tokens, - 'tool_calls': tool_calls, - 'tool_raw': { - 'inputs': agent_thought.tool_input, - 'outputs': agent_thought.observation, - }, - 'thought': agent_thought.thought, - 'created_at': agent_thought.created_at.isoformat(), - 'files': agent_thought.files, - }) + result["iterations"].append( + { + "tokens": agent_thought.tokens, + "tool_calls": tool_calls, + "tool_raw": { + "inputs": agent_thought.tool_input, + "outputs": agent_thought.observation, + }, + "thought": agent_thought.thought, + "created_at": agent_thought.created_at.isoformat(), + "files": agent_thought.files, + } + ) - return result \ No newline at end of file + return result diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index addcde44ed..3cc6c51c2d 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -23,21 +23,18 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - if args.get('message_id'): - message_id = str(args['message_id']) + if args.get("message_id"): + message_id = str(args["message_id"]) # get message info - message = 
db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -45,159 +42,166 @@ class AppAnnotationService: annotation = message.annotation # save the message annotation if annotation: - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] else: annotation = MessageAnnotation( app_id=app.id, conversation_id=message.conversation_id, message_id=message.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + content=args["answer"], + question=args["question"], + account_id=current_user.id, ) else: annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_key = 
"enable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(enable_app_annotation_job_key, 'waiting') - enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, - args['score_threshold'], - args['embedding_provider_name'], args['embedding_model_name']) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + redis_client.setnx(enable_app_annotation_job_key, "waiting") + enable_annotation_reply_task.delay( + str(job_id), + app_id, + current_user.id, + current_user.current_tenant_id, + args["score_threshold"], + args["embedding_provider_name"], + args["embedding_model_name"], + ) + return {"job_id": job_id, "job_status": "waiting"} @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(disable_app_annotation_job_key, 'waiting') + redis_client.setnx(disable_app_annotation_job_key, "waiting") 
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") if keyword: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( - or_( - MessageAnnotation.question.ilike('%{}%'.format(keyword)), - MessageAnnotation.content.ilike('%{}%'.format(keyword)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike("%{}%".format(keyword)), + MessageAnnotation.content.ilike("%{}%".format(keyword)), + ) ) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) else: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotations.items, annotations.total @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - app = 
db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()).all()) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .all() + ) return annotations @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, 
+ annotation_setting.collection_binding_id, + ) return annotation @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -207,30 +211,34 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - update_annotation_to_index_task.delay(annotation.id, annotation.question, - current_user.current_tenant_id, - app_id, app_annotation_setting.collection_binding_id) + update_annotation_to_index_task.delay( + annotation.id, + annotation.question, + current_user.current_tenant_id, + app_id, + app_annotation_setting.collection_binding_id, + ) return annotation @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -242,33 +250,34 @@ class AppAnnotationService: 
db.session.delete(annotation) - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - delete_annotation_index_task.delay(annotation.id, app_id, - current_user.current_tenant_id, - app_annotation_setting.collection_binding_id) + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -278,10 +287,7 @@ class AppAnnotationService: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = { - 'question': row[0], - 'answer': row[1] - } + content = {"question": row[0], "answer": row[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -293,28 +299,24 @@ class AppAnnotationService: raise ValueError("The number of annotations exceeds the limit of your subscription.") 
# async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_import_annotations_task.delay(str(job_id), result, app_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_import_annotations_task.delay( + str(job_id), result, app_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return { - 'error_msg': str(e) - } - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"error_msg": str(e)} + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -324,12 +326,15 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.app_id == app_id, - AppAnnotationHitHistory.annotation_id == annotation_id, - ) - .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return 
annotation_hit_histories.items, annotation_hit_histories.total @classmethod @@ -341,15 +346,21 @@ class AppAnnotationService: return annotation @classmethod - def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, - annotation_content: str, query: str, user_id: str, - message_id: str, from_source: str, score: float): + def add_annotation_history( + cls, + annotation_id: str, + app_id: str, + annotation_question: str, + annotation_content: str, + query: str, + user_id: str, + message_id: str, + from_source: str, + score: float, + ): # add hit count to annotation - db.session.query(MessageAnnotation).filter( - MessageAnnotation.id == annotation_id - ).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, - synchronize_session=False + db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) annotation_hit_history = AppAnnotationHitHistory( @@ -361,7 +372,7 @@ class AppAnnotationService: score=score, message_id=message_id, annotation_question=annotation_question, - annotation_content=annotation_content + annotation_content=annotation_content, ) db.session.add(annotation_hit_history) db.session.commit() @@ -369,17 +380,18 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == 
app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -388,32 +400,34 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } - return { - "enabled": False - } + return {"enabled": False} @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id, - AppAnnotationSetting.id == annotation_setting_id, - ).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting) + .filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ) + .first() + ) if not annotation_setting: raise NotFound("App annotation not found") - annotation_setting.score_threshold = args['score_threshold'] + annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) @@ -427,6 +441,6 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + 
"embedding_model_name": collection_binding_detail.model_name, + }, } diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 8441bbedb3..601d67d2fb 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: - @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .order_by(APIBasedExtension.created_at.desc()) \ - .all() + extension_list = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + .all() + ) for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -35,10 +36,12 @@ class APIBasedExtensionService: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ + extension = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .filter_by(id=api_based_extension_id) .first() + ) if not extension: raise ValueError("API based extension is not found") @@ -55,20 +58,24 @@ class APIBasedExtensionService: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, 
name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ - .filter(APIBasedExtension.id != extension_data.id) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) + .filter(APIBasedExtension.id != extension_data.id) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") @@ -92,7 +99,7 @@ class APIBasedExtensionService: try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) - if resp.get('result') != 'pong': + if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 737def3366..a938b4f93b 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -75,43 +75,44 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # get app basic info - name = args.get("name") if args.get("name") else app_data.get('name') - description = args.get("description") if args.get("description") else app_data.get('description', '') - icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type') - icon = args.get("icon") if args.get("icon") else app_data.get('icon') - icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') + name = args.get("name") if args.get("name") else app_data.get("name") + description = args.get("description") if args.get("description") else 
app_data.get("description", "") + icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") + icon = args.get("icon") if args.get("icon") else app_data.get("icon") + icon_background = ( + args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") + ) # import dsl and create app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get('model_config'), + model_config_data=import_data.get("model_config"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) else: raise ValueError("Invalid app mode") @@ -134,27 +135,26 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: raise ValueError("Only support import workflow in advanced-chat or workflow app.") - if app_data.get('mode') != app_model.mode: - raise ValueError( - f"App mode {app_data.get('mode')} is not matched with current app 
mode {app_mode.value}") + if app_data.get("mode") != app_model.mode: + raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, ) @classmethod - def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: """ Export app :param app_model: App instance @@ -168,14 +168,16 @@ class AppDslService: "app": { "name": app_model.name, "mode": app_model.mode, - "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon, - "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background, - "description": app_model.description - } + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + }, } if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) else: cls._append_model_config_export_data(export_data, app_model) @@ -188,31 +190,35 @@ class AppDslService: :param import_data: import data """ - if not import_data.get('version'): - import_data['version'] = "0.1.0" + if not import_data.get("version"): + import_data["version"] = "0.1.0" - if not import_data.get('kind') or import_data.get('kind') != "app": - import_data['kind'] = "app" + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" - if import_data.get('version') != current_dsl_version: + if import_data.get("version") != current_dsl_version: # Currently only 
one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning(f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") + logger.warning( + f"DSL version {import_data.get('version')} is not compatible " + f"with current version {current_dsl_version}, related to " + f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." + ) return import_data @classmethod - def _import_and_create_new_workflow_based_app(cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_workflow_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + workflow_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new workflow based app @@ -227,8 +233,7 @@ class AppDslService: :param icon_background: app icon background """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") app = cls._create_app( tenant_id=tenant_id, @@ -238,37 +243,32 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) # init draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + 
conversation_variables_list = workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('../core/app/features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("../core/app/features", {}), unique_hash=None, account=account, environment_variables=environment_variables, conversation_variables=conversation_variables, ) - workflow_service.publish_workflow( - app_model=app, - account=account, - draft_workflow=draft_workflow - ) + workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) return app @classmethod - def _import_and_overwrite_workflow_based_app(cls, - app_model: App, - workflow_data: dict, - account: Account) -> Workflow: + def _import_and_overwrite_workflow_based_app( + cls, app_model: App, workflow_data: dict, account: Account + ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -277,8 +277,7 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -289,14 +288,14 @@ class AppDslService: unique_hash = None # sync draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables_list = 
workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), unique_hash=unique_hash, account=account, environment_variables=environment_variables, @@ -306,16 +305,18 @@ class AppDslService: return draft_workflow @classmethod - def _import_and_create_new_model_config_based_app(cls, - tenant_id: str, - app_mode: AppMode, - model_config_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_model_config_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + model_config_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new model config based app @@ -329,8 +330,7 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument " - "when app mode is chat, agent-chat or completion") + raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") app = cls._create_app( tenant_id=tenant_id, @@ -340,7 +340,7 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) app_model_config = AppModelConfig() @@ -352,23 +352,22 @@ class AppDslService: app.app_model_config_id = app_model_config.id - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) + app_model_config_was_updated.send(app, app_model_config=app_model_config) return app @classmethod - def _create_app(cls, - 
tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _create_app( + cls, + tenant_id: str, + app_mode: AppMode, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Create new app @@ -390,7 +389,7 @@ class AppDslService: icon=icon, icon_background=icon_background, enable_site=True, - enable_api=True + enable_api=True, ) db.session.add(app) @@ -412,7 +411,7 @@ class AppDslService: if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data['workflow'] = workflow.to_dict(include_secret=include_secret) + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) @classmethod def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: @@ -425,4 +424,4 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data['model_config'] = app_model_config.to_dict() + export_data["model_config"] = app_model_config.to_dict() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 436274daf7..17de1142fb 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -15,14 +15,15 @@ from services.workflow_service import WorkflowService class AppGenerateService: - @classmethod - def generate(cls, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - streaming: bool = True, - ): + def generate( + cls, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ App Content Generate :param app_model: app model @@ -38,51 +39,54 @@ class AppGenerateService: try: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: - return 
rate_limit.generate(CompletionAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - return rate_limit.generate(AgentChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.CHAT.value: - return rate_limit.generate(ChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + 
WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") finally: if not streaming: rate_limit.exit(request_id) @@ -95,38 +99,31 @@ class AppGenerateService: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, - user: Account, - node_id: str, - args: Any, - streaming: bool = True): + def generate_single_iteration( + cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True + ): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[dict, Generator]: """ Generate more like this :param app_model: app model 
@@ -137,11 +134,7 @@ class AppGenerateService: :return: """ return CompletionAppGenerator().generate_more_like_this( - app_model=app_model, - message_id=message_id, - user=user, - invoke_from=invoke_from, - stream=streaming + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming ) @classmethod @@ -158,12 +151,12 @@ class AppGenerateService: workflow = workflow_service.get_draft_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not published') + raise ValueError("Workflow not published") return workflow diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c84f6fbf45..a1ad271053 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -5,7 +5,6 @@ from models.model import AppMode class AppModelConfigService: - @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index 93f7169c1f..8a2f8c0535 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -33,27 +33,22 @@ class AppService: :param args: request args :return: """ - filters = [ - App.tenant_id == tenant_id, - App.is_universal == False - ] + filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args['mode'] == 'workflow': + if args["mode"] == "workflow": filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': + elif args["mode"] == "chat": filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent-chat': + elif args["mode"] == 
"agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': + elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get('name'): - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - if args.get('tag_ids'): - target_ids = TagService.get_target_ids_by_tag_ids('app', - tenant_id, - args['tag_ids']) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) if target_ids: filters.append(App.id.in_(target_ids)) else: @@ -61,9 +56,9 @@ class AppService: app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False + page=args["page"], + per_page=args["limit"], + error_out=False, ) return app_models @@ -75,21 +70,20 @@ class AppService: :param args: request args :param account: Account instance """ - app_mode = AppMode.value_of(args['mode']) + app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template.get('model_config') + default_model_config = app_template.get("model_config") default_model_config = default_model_config.copy() if default_model_config else None - if default_model_config and 'model' in default_model_config: + if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -98,39 +92,41 @@ class AppService: model_instance = None if model_instance: - if model_instance.model == 
default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: - default_model_dict = default_model_config['model'] + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) - default_model_config['model']['provider'] = provider - default_model_config['model']['name'] = model - default_model_dict = default_model_config['model'] + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] - default_model_config['model'] = json.dumps(default_model_dict) + default_model_config["model"] = json.dumps(default_model_dict) - app = App(**app_template['app']) - app.name = args['name'] - app.description = args.get('description', '') - app.mode = args['mode'] - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args['icon'] - app.icon_background = args['icon_background'] + app = App(**app_template["app"]) + app.name = args["name"] + app.description = args.get("description", "") + app.mode = args["mode"] + app.icon_type = args.get("icon_type", "emoji") + app.icon = 
args["icon"] + app.icon_background = args["icon_background"] app.tenant_id = tenant_id - app.api_rph = args.get('api_rph', 0) - app.api_rpm = args.get('api_rpm', 0) + app.api_rph = args.get("api_rph", 0) + app.api_rpm = args.get("api_rpm", 0) db.session.add(app) db.session.flush() @@ -158,7 +154,7 @@ class AppService: model_config: AppModelConfig = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue agent_tool_entity = AgentToolEntity(**tool) @@ -174,7 +170,7 @@ class AppService: tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app.id}' + identity_id=f"AGENT.{app.id}", ) # get decrypted parameters @@ -185,7 +181,7 @@ class AppService: masked_parameter = {} # override tool parameters - tool['tool_parameters'] = masked_parameter + tool["tool_parameters"] = masked_parameter except Exception as e: pass @@ -215,12 +211,12 @@ class AppService: :param args: request args :return: App instance """ - app.name = args.get('name') - app.description = args.get('description', '') - app.max_active_requests = args.get('max_active_requests') - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') + app.name = args.get("name") + app.description = args.get("description", "") + app.max_active_requests = args.get("max_active_requests") + app.icon_type = args.get("icon_type", "emoji") + app.icon = args.get("icon") + app.icon_background = args.get("icon_background") app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -298,10 +294,7 @@ class AppService: db.session.commit() # Trigger asynchronous deletion of app and related data - remove_app_and_related_data_task.delay( - 
tenant_id=app.tenant_id, - app_id=app.id - ) + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) def get_app_meta(self, app_model: App) -> dict: """ @@ -311,9 +304,7 @@ class AppService: """ app_mode = AppMode.value_of(app_model.mode) - meta = { - 'tool_icons': {} - } + meta = {"tool_icons": {}} if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: workflow = app_model.workflow @@ -321,17 +312,19 @@ class AppService: return meta graph = workflow.graph_dict - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) tools = [] for node in nodes: - if node.get('data', {}).get('type') == 'tool': - node_data = node.get('data', {}) - tools.append({ - 'provider_type': node_data.get('provider_type'), - 'provider_id': node_data.get('provider_id'), - 'tool_name': node_data.get('tool_name'), - 'tool_parameters': {} - }) + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) else: app_model_config: AppModelConfig = app_model.app_model_config @@ -341,30 +334,26 @@ class AppService: agent_config = app_model_config.agent_mode_dict or {} # get all tools - tools = agent_config.get('tools', []) + tools = agent_config.get("tools", []) - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/") + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get('provider_type') - provider_id = tool.get('provider_id') - tool_name = tool.get('tool_name') - if provider_type == 'builtin': - meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' - elif provider_type == 'api': + provider_type = tool.get("provider_type") + 
provider_id = tool.get("provider_id") + tool_name = tool.get("tool_name") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id - ).first() - meta['tool_icons'][tool_name] = json.loads(provider.icon) + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + meta["tool_icons"][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { - "background": "#252525", - "content": "\ud83d\ude01" - } + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 58c950816f..05cd1c96a1 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -17,7 +17,7 @@ from services.errors.audio import ( FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 -ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] logger = logging.getLogger(__name__) @@ -31,19 +31,19 @@ class AudioService: raise ValueError("Speech to text is not enabled") features_dict = workflow.features_dict - if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = app_model.app_model_config - if not app_model_config.speech_to_text_dict['enabled']: + if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") if file is None: raise NoAudioUploadedServiceError() extension = file.mimetype - if extension not in [f'audio/{ext}' 
for ext in ALLOWED_EXTENSIONS]: + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() file_content = file.read() @@ -55,20 +55,25 @@ class AudioService: model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.SPEECH2TEXT + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() buffer = io.BytesIO(file_content) - buffer.name = 'temp.mp3' + buffer.name = "temp.mp3" return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: Optional[str] = None, - voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): from collections.abc import Generator from flask import Response, stream_with_context @@ -84,65 +89,56 @@ class AudioService: raise ValueError("TTS is not enabled") features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get('enabled'): + if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + voice = text_to_speech_dict.get("voice") if voice is None else voice 
model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) try: if not voice: voices = model_instance.get_tts_voices() if voices: - voice = voices[0].get('value') + voice = voices[0].get("value") else: raise ValueError("Sorry, no voice available.") return model_instance.invoke_tts( - content_text=text_content.strip(), - user=end_user, - tenant_id=app_model.tenant_id, - voice=voice + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) except Exception as e: raise e if message_id: - message = db.session.query(Message).filter( - Message.id == message_id - ).first() - if message.answer == '' and message.status == 'normal': + message = db.session.query(Message).filter(Message.id == message_id).first() + if message.answer == "" and message.status == "normal": return None else: response = invoke_tts(message.answer, app_model=app_model, voice=voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response else: response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TTS - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_factory.py 
b/api/services/auth/api_key_auth_factory.py index ccd0023c44..ae5b953b47 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,14 +1,12 @@ - from services.auth.firecrawl import FirecrawlAuth class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): - if provider == 'firecrawl': + if provider == "firecrawl": self.auth = FirecrawlAuth(credentials) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") def validate_credentials(self): return self.auth.validate_credentials() diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 43d0fbf98f..e5f4a3ef6e 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: - @staticmethod def get_provider_auth_list(tenant_id: str) -> list: - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).all() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) return data_source_api_key_bindings @staticmethod def create_provider_auth(tenant_id: str, args: dict): - auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key - api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) - args['credentials']['config']['api_key'] = api_key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + 
args["credentials"]["config"]["api_key"] = api_key data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id - data_source_api_key_binding.category = args['category'] - data_source_api_key_binding.provider = args['provider'] - data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) db.session.add(data_source_api_key_binding) db.session.commit() @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.category == category, - DataSourceApiKeyAuthBinding.provider == provider, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).first() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) if not data_source_api_key_bindings: return None credentials = json.loads(data_source_api_key_bindings.credentials) @@ -47,24 +51,24 @@ class ApiKeyAuthService: @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.id == binding_id - ).first() + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + 
) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) db.session.commit() @classmethod def validate_api_key_auth_args(cls, args): - if 'category' not in args or not args['category']: - raise ValueError('category is required') - if 'provider' not in args or not args['provider']: - raise ValueError('provider is required') - if 'credentials' not in args or not args['credentials']: - raise ValueError('credentials is required') - if not isinstance(args['credentials'], dict): - raise ValueError('credentials must be a dictionary') - if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: - raise ValueError('auth_type is required') - + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 69e3fb43c7..30e4ee57c0 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase class FirecrawlAuth(ApiKeyAuthBase): def __init__(self, credentials: dict): super().__init__(credentials) - auth_type = credentials.get('auth_type') - if auth_type != 'bearer': - raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') - self.api_key = credentials.get('config').get('api_key', None) - self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, 
Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") if not self.api_key: - raise ValueError('No API key provided') + raise ValueError("No API key provided") def validate_credentials(self): headers = self._prepare_headers() options = { - 'url': 'https://example.com', - 'crawlerOptions': { - 'excludes': [], - 'includes': [], - 'limit': 1 - }, - 'pageOptions': { - 'onlyMainContent': True - } + "url": "https://example.com", + "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, + "pageOptions": {"onlyMainContent": True}, } - response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) if response.status_code == 200: return True else: self._handle_error(response) def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: if response.text: - error_message = json.loads(response.text).get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - raise Exception(f'Unexpected error occurred while trying to authorize. 
Status code: {response.status_code}') + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 539f2712bb..911d234641 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole class BillingService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def get_info(cls, tenant_id: str): - params = {'tenant_id': tenant_id} + params = {"tenant_id": tenant_id} - billing_info = cls._send_request('GET', '/subscription/info', params=params) + billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod - def get_subscription(cls, plan: str, - interval: str, - prefilled_email: str = '', - tenant_id: str = ''): - params = { - 'plan': plan, - 'interval': interval, - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/subscription/payment-link', params=params) + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) @classmethod - def get_model_provider_payment_link(cls, - provider_name: str, - tenant_id: str, - account_id: str, - prefilled_email: str): + def 
get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): params = { - 'provider_name': provider_name, - 'tenant_id': tenant_id, - 'account_id': account_id, - 'prefilled_email': prefilled_email + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, } - return cls._send_request('GET', '/model-provider/payment-link', params=params) + return cls._send_request("GET", "/model-provider/payment-link", params=params) @classmethod - def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): - params = { - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/invoices', params=params) + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -69,10 +51,11 @@ class BillingService: def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == current_user.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) if not TenantAccountRole.is_privileged_role(join.role): - raise ValueError('Only team owner or team admin can perform this action') + raise ValueError("Only team owner or team admin can 
perform this action") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 7b0d50a835..f7597b7f1f 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension class CodeBasedExtensionService: - @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) - return [{ - 'name': module_extension.name, - 'label': module_extension.label, - 'form_schema': module_extension.form_schema - } for module_extension in module_extensions if not module_extension.builtin] + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 053d704e17..7bfe59afa0 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError class ConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, - invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, - sort_by: str = '-updated_at') -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None, + sort_by: str = "-updated_at", + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = 
db.session.query(Conversation).filter( Conversation.is_deleted == False, Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) if include_ids is not None: @@ -58,28 +63,26 @@ class ConversationService: has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] - rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction, - current_page_last_conversation, is_next_page=True) + rest_filter_condition = cls._build_filter_condition( + sort_field, sort_direction, current_page_last_conversation, is_next_page=True + ) rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=conversations, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: - if sort_by.startswith('-'): + if sort_by.startswith("-"): return sort_by[1:], desc return sort_by, asc @classmethod - def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, - is_next_page: bool = False): + def _build_filter_condition( + cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + ): field_value = getattr(reference_conversation, sort_field) if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): 
return getattr(Conversation, sort_field) < field_value @@ -87,8 +90,14 @@ class ConversationService: return getattr(Conversation, sort_field) > field_value @classmethod - def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): conversation = cls.get_conversation(app_model, conversation_id, user) if auto_generate: @@ -103,11 +112,12 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = db.session.query(Message) \ - .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) if not message: raise MessageNotExistsError() @@ -127,15 +137,18 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = db.session.query(Conversation) \ + conversation = ( + db.session.query(Conversation) .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - Conversation.is_deleted == False - ).first() + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) 
else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) if not conversation: raise ConversationNotExistsError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index d547014866..8649d0fea5 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -55,7 +55,6 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: - @staticmethod def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( @@ -64,10 +63,7 @@ class DatasetService: if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by( - account_id=user.id, - tenant_id=tenant_id - ).all() + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -83,14 +79,17 @@ class DatasetService: db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), - db.and_(Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.id.in_(permitted_dataset_ids)) + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), ) ) else: query = query.filter( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id) + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), ) ) else: @@ -98,49 +97,40 @@ class DatasetService: query = query.filter(Dataset.permission == 
DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f'%{search}%')) + query = query.filter(Dataset.name.ilike(f"%{search}%")) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = query.paginate( - page=page, - per_page=per_page, - max_per_page=100, - error_out=False - ) + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict else: - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] - return { - 'mode': mode, - 'rules': rules - } + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter( - Dataset.id.in_(ids), - Dataset.tenant_id == tenant_id - ).paginate( + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( page=1, per_page=len(ids), max_per_page=len(ids), error_out=False ) return datasets.items, datasets.total @@ -149,15 +139,12 @@ class DatasetService: def create_empty_dataset(tenant_id: str, name: str, 
indexing_technique: Optional[str], account: Account): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): - raise DatasetNameDuplicateError( - f'Dataset with name {name} already exists.' - ) + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) @@ -172,20 +159,18 @@ class DatasetService: @staticmethod def get_dataset(dataset_id): - return Dataset.query.filter_by( - id=dataset_id - ).first() + return Dataset.query.filter_by(id=dataset_id).first() @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ValueError( @@ -193,65 +178,56 @@ class DatasetService: "in the Settings -> Model Provider." 
) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod - def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str): + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model + model=embedding_model, ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) - + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod def update_dataset(dataset_id, data, user): - data.pop('partial_member_list', None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} + data.pop("partial_member_list", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_permission(dataset, user) action = None - if dataset.indexing_technique != data['indexing_technique']: + if dataset.indexing_technique != data["indexing_technique"]: # if update indexing_technique - if data['indexing_technique'] == 'economy': - action = 'remove' - filtered_data['embedding_model'] = None - filtered_data['embedding_model_provider'] = None - filtered_data['collection_binding_id'] = None - elif data['indexing_technique'] == 'high_quality': - action = 'add' + if 
data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. 
Please configure a valid provider " @@ -260,24 +236,25 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) else: - if data['embedding_model_provider'] != dataset.embedding_model_provider or \ - data['embedding_model'] != dataset.embedding_model: - action = 'update' + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. 
Please configure a valid provider " @@ -286,11 +263,11 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - filtered_data['updated_by'] = user.id - filtered_data['updated_at'] = datetime.datetime.now() + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() # update Retrieval model - filtered_data['retrieval_model'] = data['retrieval_model'] + filtered_data["retrieval_model"] = data["retrieval_model"] dataset.query.filter_by(id=dataset_id).update(filtered_data) @@ -301,7 +278,6 @@ class DatasetService: @staticmethod def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -325,72 +301,57 @@ class DatasetService: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' 
- ) - if dataset.permission == 'partial_members': - user_permission = DatasetPermission.query.filter_by( - dataset_id=dataset.id, account_id=user.id - ).first() + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() ): - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ - .order_by(db.desc(DatasetQuery.created_at)) \ - .paginate( - page=page, per_page=per_page, max_per_page=100, error_out=False + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + 
.order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) ) return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): - return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ - .order_by(db.desc(AppDatasetJoin.created_at)).all() + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) class DocumentService: DEFAULT_RULES = { - 'mode': 'custom', - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -483,58 +444,55 @@ class DocumentService: "commit_date": str, "commit_author": str, }, - "others": dict + "others": dict, } @staticmethod def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) return document @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() return document @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - 
Document.enabled == True - ).all() + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() return documents @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.indexing_status.in_(['error', 'paused']) - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) return documents @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.batch == batch, - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) return documents @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == file_id). 
\ - one_or_none() + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -548,13 +506,14 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - document_was_deleted.send(document.id, dataset_id=document.dataset_id, - doc_form=document.doc_form, file_id=file_id) + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) db.session.delete(document) db.session.commit() @@ -563,15 +522,15 @@ class DocumentService: def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise ValueError('Dataset not found.') + raise ValueError("Dataset not found.") document = DocumentService.get_document(dataset_id, document_id) if not document: - raise ValueError('Document not found.') + raise ValueError("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise ValueError('No permission.') + raise ValueError("No permission.") document.name = name @@ -592,7 +551,7 @@ class DocumentService: db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -607,7 +566,7 @@ class DocumentService: db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = 
'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.delete(indexing_cache_key) # trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -616,12 +575,12 @@ class DocumentService: def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) db.session.commit() @@ -633,14 +592,14 @@ class DocumentService: @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info['mode'] = 'scrape' + data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() @@ -659,27 +618,28 @@ class DocumentService: @staticmethod def save_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = 
None, + created_from: str = "web", ): - # check document limit features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if 'original_document_id' not in document_data or not document_data['original_document_id']: + if "original_document_id" not in document_data or not document_data["original_document_id"]: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -691,42 +651,41 @@ class DocumentService: dataset.data_source_type = document_data["data_source"]["type"] if not dataset.indexing_technique: - if 'indexing_technique' not in document_data \ - or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST: + if ( + "indexing_technique" not in document_data + or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST + ): raise ValueError("Indexing technique is required") dataset.indexing_technique = document_data["indexing_technique"] - if 
document_data["indexing_technique"] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( - 'retrieval_model' - ) else default_retrieval_model + dataset.retrieval_model = ( + document_data.get("retrieval_model") + if document_data.get("retrieval_model") + else default_retrieval_model + ) documents = [] - batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) @@ -739,14 +698,14 @@ class DocumentService: dataset_id=dataset.id, mode=process_rule["mode"], 
rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() @@ -754,12 +713,13 @@ class DocumentService: document_ids = [] duplicate_document_ids = [] if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -770,34 +730,39 @@ class DocumentService: "upload_file_id": file_id, } # check duplicate - if document_data.get('duplicate', False): + if document_data.get("duplicate", False): document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='upload_file', + data_source_type="upload_file", enabled=True, - name=file_name + name=file_name, ).first() if document: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = datetime.datetime.utcnow() document.created_from = created_from - document.doc_form = document_data['doc_form'] - document.doc_language = document_data['doc_language'] + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = 'waiting' + document.indexing_status = 
"waiting" db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) continue document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, file_name, batch + data_source_info, + created_from, + position, + account, + file_name, + batch, ) db.session.add(document) db.session.flush() @@ -805,47 +770,52 @@ class DocumentService: documents.append(document) position += 1 elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) - exist_document[data_source_info['notion_page_id']] = document.id + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) 
).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: - if page['page_id'] not in exist_page_ids: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, page['page_name'], batch + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, ) db.session.add(document) db.session.flush() @@ -853,32 +823,37 @@ class DocumentService: documents.append(document) position += 1 else: - exist_document.pop(page['page_id']) + exist_document.pop(page["page_id"]) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } if len(url) > 255: - 
document_name = url[:200] + '...' + document_name = url[:200] + "..." else: document_name = url document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, document_name, batch + data_source_info, + created_from, + position, + account, + document_name, + batch, ) db.session.add(document) db.session.flush() @@ -900,15 +875,22 @@ class DocumentService: can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: raise ValueError( - f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.' + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." ) @staticmethod def build_document( - dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - document_language: str, data_source_info: dict, created_from: str, position: int, + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, account: Account, - name: str, batch: str + name: str, + batch: str, ): document = Document( tenant_id=dataset.tenant_id, @@ -922,7 +904,7 @@ class DocumentService: created_from=created_from, created_by=account.id, doc_form=document_form, - doc_language=document_language + doc_language=document_language, ) return document @@ -932,54 +914,57 @@ class DocumentService: Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, - Document.tenant_id == current_user.current_tenant_id + Document.tenant_id == current_user.current_tenant_id, ).count() return documents_count @staticmethod def update_document_with_dataset_id( - dataset: Dataset, document_data: dict, - 
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) - if document.display_status != 'available': + if document.display_status != "available": raise ValueError("Document is not available") # update document name - if document_data.get('name'): - document.name = document_data['name'] + if document_data.get("name"): + document.name = document_data["name"] # save process rule - if document_data.get('process_rule'): + if document_data.get("process_rule"): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get('data_source'): - file_name = '' + if document_data.get("data_source"): + file_name = "" data_source_info = {} if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - 
).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -990,42 +975,42 @@ class DocumentService: "upload_file_id": file_id, } elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', 
False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document to be waiting - document.indexing_status = 'waiting' + document.indexing_status = "waiting" document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -1033,13 +1018,11 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data['doc_form'] + document.doc_form = document_data["doc_form"] db.session.add(document) db.session.commit() # update document segment - update_params = { - DocumentSegment.status: 're_segment' - } + update_params = {DocumentSegment.status: "re_segment"} DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task @@ -1053,15 +1036,15 @@ class DocumentService: if features.billing.enabled: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = 
document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1071,42 +1054,37 @@ class DocumentService: embedding_model = None dataset_collection_binding_id = None retrieval_model = None - if document_data['indexing_technique'] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get('retrieval_model'): - retrieval_model = document_data['retrieval_model'] + if document_data.get("retrieval_model"): + retrieval_model = document_data["retrieval_model"] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, - name='', + name="", data_source_type=document_data["data_source"]["type"], 
indexing_technique=document_data["indexing_technique"], created_by=account.id, embedding_model=embedding_model.model if embedding_model else None, embedding_model_provider=embedding_model.provider if embedding_model else None, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model + retrieval_model=retrieval_model, ) db.session.add(dataset) @@ -1116,236 +1094,259 @@ class DocumentService: cut_length = 18 cut_name = documents[0].name[:cut_length] - dataset.name = cut_name + '...' - dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name + dataset.name = cut_name + "..." + dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() return dataset, documents, batch @classmethod def document_create_args_validate(cls, args: dict): - if 'original_document_id' not in args or not args['original_document_id']: + if "original_document_id" not in args or not args["original_document_id"]: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source']) \ - and ('process_rule' not in args and not args['process_rule']): + if ("data_source" not in args and not args["data_source"]) and ( + "process_rule" not in args and not args["process_rule"] + ): raise ValueError("Data source or Process rule is required") else: - if args.get('data_source'): + if args.get("data_source"): DocumentService.data_source_args_validate(args) - if args.get('process_rule'): + if args.get("process_rule"): DocumentService.process_rule_args_validate(args) @classmethod def data_source_args_validate(cls, args: dict): - if 'data_source' not in args or not args['data_source']: + if "data_source" not in args or not args["data_source"]: raise ValueError("Data source is required") - if not isinstance(args['data_source'], dict): + if not isinstance(args["data_source"], dict): raise 
ValueError("Data source is invalid") - if 'type' not in args['data_source'] or not args['data_source']['type']: + if "type" not in args["data_source"] or not args["data_source"]["type"]: raise ValueError("Data source type is required") - if args['data_source']['type'] not in Document.DATA_SOURCES: + if args["data_source"]["type"] not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if 'info_list' not in args['data_source'] or not args['data_source']['info_list']: + if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: raise ValueError("Data source info is required") - if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'file_info_list']: + if args["data_source"]["type"] == "upload_file": + if ( + "file_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["file_info_list"] + ): raise ValueError("File source info is required") - if args['data_source']['type'] == 'notion_import': - if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'notion_info_list']: + if args["data_source"]["type"] == "notion_import": + if ( + "notion_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["notion_info_list"] + ): raise ValueError("Notion source info is required") - if args['data_source']['type'] == 'website_crawl': - if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'website_info_list']: + if args["data_source"]["type"] == "website_crawl": + if ( + "website_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["website_info_list"] + ): raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): - if 'process_rule' not in args or not 
args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for 
pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not 
args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): - if 'info_list' not in args or not args['info_list']: + if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args['info_list'], dict): + if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process 
rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule 
pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not 
args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == 'qa_model': - if 'answer' not in args or not args['answer']: + if document.doc_form == "qa_model": + if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") - if not args['answer'].strip(): + if not args["answer"].strip(): raise ValueError("Answer is empty") - if 'content' not in args or not args['content'] or not args['content'].strip(): + if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): - content = args['content'] + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) - lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): - max_position = 
db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1356,25 +1357,25 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] + if document.doc_form == "qa_model": + segment_document.answer = args["answer"] db.session.add(segment_document) db.session.commit() # save vector index try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() @@ -1382,33 +1383,33 @@ class SegmentService: @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == 'high_quality': + 
if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) pre_segment_data_list = [] segment_data_list = [] keywords_list = [] for segment_item in segments: - content = segment_item['content'] + content = segment_item["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: + if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1419,19 +1420,19 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] db.session.add(segment_document) segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - if 'keywords' in segment_item: - 
keywords_list.append(segment_item['keywords']) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: keywords_list.append(None) @@ -1443,19 +1444,19 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if 'enabled' in args and args['enabled'] is not None: - action = args['enabled'] + if "enabled" in args and args["enabled"] is not None: + action = args["enabled"] if segment.enabled != action: if not action: segment.enabled = action @@ -1468,25 +1469,25 @@ class SegmentService: disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if 'enabled' in args and args['enabled'] is not None: - if not args['enabled']: + if "enabled" in args and args["enabled"] is not None: + if not args["enabled"]: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: - content = args['content'] + content = args["content"] if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args.get('keywords'): - segment.keywords = args['keywords'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] + if args.get("keywords"): + segment.keywords = args["keywords"] segment.enabled = True segment.disabled_at = None segment.disabled_by = 
None db.session.add(segment) db.session.commit() # update segment index task - if 'keywords' in args: + if "keywords" in args: keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) document = RAGDocument( @@ -1496,30 +1497,28 @@ class SegmentService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - keyword.add_texts([document], keywords_list=[args['keywords']]) + keyword.add_texts([document], keywords_list=[args["keywords"]]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = 'completed' + segment.status = "completed" segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id @@ -1527,18 +1526,18 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == 'qa_model': - segment.answer = args['answer'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] db.session.add(segment) db.session.commit() # update segment vector index - VectorService.update_segment_vector(args['keywords'], segment, dataset) + 
VectorService.update_segment_vector(args["keywords"], segment, dataset) except Exception as e: logging.exception("update segment index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() @@ -1546,7 +1545,7 @@ class SegmentService: @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -1563,24 +1562,25 @@ class SegmentService: class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, - collection_type: str = 'dataset' + cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). 
\ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type=collection_type + type=collection_type, ) db.session.add(dataset_collection_binding) db.session.commit() @@ -1588,16 +1588,16 @@ class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, - collection_type: str = 'dataset' + cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.id == collection_binding_id, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). 
\ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) return dataset_collection_binding @@ -1605,11 +1605,13 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.query( - DatasetPermission.account_id, - ).filter( - DatasetPermission.dataset_id == dataset_id - ).all() + user_list_query = ( + db.session.query( + DatasetPermission.account_id, + ) + .filter(DatasetPermission.dataset_id == dataset_id) + .all() + ) user_list = [] for user in user_list_query: @@ -1626,7 +1628,7 @@ class DatasetPermissionService: permission = DatasetPermission( tenant_id=tenant_id, dataset_id=dataset_id, - account_id=user['user_id'], + account_id=user["user_id"], ) permissions.append(permission) @@ -1639,19 +1641,19 @@ class DatasetPermissionService: @classmethod def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: - raise NoPermissionError('User does not have permission to edit this dataset.') + raise NoPermissionError("User does not have permission to edit this dataset.") if user.is_dataset_operator and dataset.permission != requested_permission: - raise NoPermissionError('Dataset operators cannot change the dataset permissions.') + raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == 'partial_members': + if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: - raise ValueError('Partial member list is required when setting to partial members.') + raise ValueError("Partial member list is required when setting to partial members.") 
local_member_list = cls.get_dataset_partial_member_list(dataset.id) - request_member_list = [user['user_id'] for user in requested_partial_member_list] + request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): - raise ValueError('Dataset operators cannot change the dataset permissions.') + raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod def clear_partial_member_list(cls, dataset_id): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index c483d28152..ddee52164b 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -4,15 +4,12 @@ import requests class EnterpriseRequest: - base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') - secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Enterprise-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 115d0d5523..abc01ddf8f 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -2,7 +2,10 @@ from services.enterprise.base import EnterpriseRequest class EnterpriseService: - @classmethod def get_info(cls): - return EnterpriseRequest.send_request('GET', '/info') + return EnterpriseRequest.send_request("GET", "/info") + + @classmethod + def 
get_app_web_sso_enabled(cls, app_code): + return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index e5e4d7e235..c519f0b0e5 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): """ Enum class for custom configuration status. """ - ACTIVE = 'active' - NO_CONFIGURE = 'no-configure' + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" class CustomConfigurationResponse(BaseModel): """ Model class for provider custom configuration response. """ + status: CustomConfigurationStatus @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): """ Model class for provider system configuration response. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): """ Model class for provider response. 
""" + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. 
""" + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): """ Default model entity. 
""" + model: str model_type: ModelType provider: SimpleProviderEntityResponse @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): """ Model with provider entity. """ + provider: SimpleProviderEntityResponse def __init__(self, model: ModelWithProviderEntity) -> None: diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ddc2dbdea8..cae31c5066 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError): class RateLimitExceededError(BaseServiceError): pass - diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f1..1fed71cf9e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,3 @@ class BaseServiceError(Exception): def __init__(self, description: str = None): - self.description = description \ No newline at end of file + self.description = description diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 83e675a9d2..4d5812c6c6 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService class SubscriptionModel(BaseModel): - plan: str = 'sandbox' - interval: str = '' + plan: str = "sandbox" + interval: str = "" class BillingModel(BaseModel): @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): vector_space: LimitationModel = LimitationModel(size=0, limit=5) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) - docs_processing: str = 'standard' + docs_processing: str = "standard" can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False @@ -38,13 +38,13 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False - 
sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_signin_protocol: str = "" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = '' + sso_enforced_for_web_protocol: str = "" + enable_web_sso_switch_component: bool = False class FeatureService: - @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() @@ -61,6 +61,7 @@ class FeatureService: system_features = SystemFeatureModel() if dify_config.ENTERPRISE_ENABLED: + system_features.enable_web_sso_switch_component = True cls._fulfill_params_from_enterprise(system_features) return system_features @@ -75,44 +76,44 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info['enabled'] - features.billing.subscription.plan = billing_info['subscription']['plan'] - features.billing.subscription.interval = billing_info['subscription']['interval'] + features.billing.enabled = billing_info["enabled"] + features.billing.subscription.plan = billing_info["subscription"]["plan"] + features.billing.subscription.interval = billing_info["subscription"]["interval"] - if 'members' in billing_info: - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if "members" in billing_info: + features.members.size = billing_info["members"]["size"] + features.members.limit = billing_info["members"]["limit"] - if 'apps' in billing_info: - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if "apps" in billing_info: + features.apps.size = billing_info["apps"]["size"] + features.apps.limit = billing_info["apps"]["limit"] - if 'vector_space' in billing_info: - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if "vector_space" in 
billing_info: + features.vector_space.size = billing_info["vector_space"]["size"] + features.vector_space.limit = billing_info["vector_space"]["limit"] - if 'documents_upload_quota' in billing_info: - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if "documents_upload_quota" in billing_info: + features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] + features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if 'annotation_quota_limit' in billing_info: - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if "annotation_quota_limit" in billing_info: + features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] + features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if 'docs_processing' in billing_info: - features.docs_processing = billing_info['docs_processing'] + if "docs_processing" in billing_info: + features.docs_processing = billing_info["docs_processing"] - if 'can_replace_logo' in billing_info: - features.can_replace_logo = billing_info['can_replace_logo'] + if "can_replace_logo" in billing_info: + features.can_replace_logo = billing_info["can_replace_logo"] - if 'model_load_balancing_enabled' in billing_info: - features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + if "model_load_balancing_enabled" in billing_info: + features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = 
enterprise_info['sso_enforced_for_signin_protocol'] - features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] - features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 9139962240..5780abb2be 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -17,27 +17,45 @@ from models.account import Account from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] -UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] +ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] +UNSTRUCTURED_ALLOWED_EXTENSIONS = [ + "txt", + "markdown", + "md", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "docx", + "csv", + "eml", + "msg", + "pptx", + "ppt", + "xml", + "epub", +] PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: filename = file.filename - extension = file.filename.split('.')[-1] + extension = file.filename.split(".")[-1] if len(filename) > 200: - filename = 
filename.split('.')[0][:200] + '.' + extension + filename = filename.split(".")[0][:200] + "." + extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + allowed_extensions = ( + UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + ) if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -55,7 +73,7 @@ class FileService: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > file_size_limit: - message = f'File size exceeded. {file_size} > {file_size_limit}' + message = f"File size exceeded. {file_size} > {file_size_limit}" raise FileTooLargeError(message) # user uuid as file name @@ -67,7 +85,7 @@ class FileService: # end_user current_tenant_id = user.tenant_id - file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension + file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." 
+ extension # save file to storage storage.save(file_key, file_content) @@ -81,11 +99,11 @@ class FileService: size=file_size, extension=extension, mime_type=file.mimetype, - created_by_role=('account' if isinstance(user, Account) else 'end_user'), + created_by_role=("account" if isinstance(user, Account) else "end_user"), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest() + hash=hashlib.sha3_256(file_content).hexdigest(), ) db.session.add(upload_file) @@ -99,10 +117,10 @@ class FileService: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" # save file to storage - storage.save(file_key, text.encode('utf-8')) + storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( @@ -111,13 +129,13 @@ class FileService: key=file_key, name=text_name, size=len(text), - extension='txt', - mime_type='text/plain', + extension="txt", + mime_type="text/plain", created_by=current_user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) @@ -127,9 +145,7 @@ class FileService: @staticmethod def get_file_preview(file_id: str) -> str: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -137,12 +153,12 @@ class FileService: # extract text from file extension = upload_file.extension etl_type = dify_config.ETL_TYPE - 
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = text[0:PREVIEW_WORDS_LIMIT] if text else "" return text @@ -152,9 +168,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -170,9 +184,7 @@ class FileService: @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index de5f6994b0..db99064814 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,14 +9,11 @@ from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + 
"top_k": 2, + "score_threshold_enabled": False, } @@ -27,9 +24,9 @@ class HitTestingService: return { "query": { "content": query, - "tsne_position": {'x': 0, 'y': 0}, + "tsne_position": {"x": 0, "y": 0}, }, - "records": [] + "records": [], } start = time.perf_counter() @@ -38,28 +35,28 @@ class HitTestingService: if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=cls.escape_query_for_search(query), - top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + all_documents = RetrievalService.retrieve( + retrival_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + top_k=retrieval_model.get("top_k", 2), + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source='hit_testing', - created_by_role='account', - created_by=account.id + dataset_id=dataset.id, 
content=query, source="hit_testing", created_by_role="account", created_by=account.id ) db.session.add(dataset_query) @@ -72,14 +69,18 @@ class HitTestingService: i = 0 records = [] for document in documents: - index_node_id = document.metadata['doc_id'] + index_node_id = document.metadata["doc_id"] - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == 'completed', - DocumentSegment.index_node_id == index_node_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) if not segment: i += 1 @@ -87,7 +88,7 @@ class HitTestingService: record = { "segment": segment, - "score": document.metadata.get('score', None), + "score": document.metadata.get("score", None), } records.append(record) @@ -98,15 +99,15 @@ class HitTestingService: "query": { "content": query, }, - "records": records + "records": records, } @classmethod def hit_testing_args_check(cls, args): - query = args['query'] + query = args["query"] if not query or len(query) > 250: - raise ValueError('Query is required and cannot exceed 250 characters') + raise ValueError("Query is required and cannot exceed 250 characters") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/message_service.py b/api/services/message_service.py index 491a914c77..ecb121c36e 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService class MessageService: @classmethod - def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + def pagination_by_first_id( + cls, + app_model: 
App, + user: Optional[Union[Account, EndUser]], + conversation_id: str, + first_id: Optional[str], + limit: int, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -36,52 +42,69 @@ class MessageService: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) if first_id: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .first() + ) if not first_message: raise FirstMessageNotExistsError() - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) has_more = False if len(history_messages) == limit: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != 
current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, - include_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + conversation_id: Optional[str] = None, + include_ids: Optional[list] = None, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -89,9 +112,7 @@ class MessageService: if conversation_id is not None: conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) base_query = base_query.filter(Message.conversation_id == conversation.id) @@ -105,10 +126,12 @@ class MessageService: if not last_message: raise LastMessageNotExistsError() - history_messages = base_query.filter( - Message.created_at < last_message.created_at, - Message.id != last_message.id - ).order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: history_messages = 
base_query.order_by(Message.created_at.desc()).limit(limit).all() @@ -116,30 +139,22 @@ class MessageService: if len(history_messages) == limit: current_page_first_message = history_messages[-1] rest_count = base_query.filter( - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id + Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], - rating: Optional[str]) -> MessageFeedback: + def create_feedback( + cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + ) -> MessageFeedback: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback @@ -148,14 +163,14 @@ class MessageService: elif rating and feedback: feedback.rating = rating elif not rating and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=rating, - from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) 
else None), ) @@ -167,13 +182,17 @@ class MessageService: @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -181,27 +200,22 @@ class MessageService: return message @classmethod - def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, invoke_from: InvokeFrom) -> list[Message]: + def get_suggested_questions_after_answer( + cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + ) -> list[Message]: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=message.conversation_id, - user=user + app_model=app_model, conversation_id=message.conversation_id, user=user ) if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() 
model_manager = ModelManager() @@ -216,24 +230,23 @@ class MessageService: if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.LLM + tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) else: if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -249,16 +262,13 @@ class MessageService: model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], + provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + model=app_model_config.model_dict["name"], ) # get memory of conversation (read-only) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) histories = memory.get_history_prompt_text( max_token_limit=3000, @@ -267,18 +277,14 @@ class MessageService: with measure_time() as timer: questions = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, - histories=histories + 
tenant_id=app_model.tenant_id, histories=histories ) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, - message_id=message_id, - suggested_question=questions, - timer=timer + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 80eb72140d..e7b9422cfe 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -46,10 +45,7 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # Enable model load balancing - provider_configuration.enable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -70,13 +66,11 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # disable model load balancing - provider_configuration.disable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ - -> tuple[bool, list[dict]]: + def get_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str + ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. 
:param tenant_id: workspace id @@ -107,20 +101,24 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True # Get load balancing configurations - load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).order_by(LoadBalancingModelConfig.created_at).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .order_by(LoadBalancingModelConfig.created_at) + .all() + ) if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials inherit_config_exists = False for load_balancing_config in load_balancing_configs: - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config_exists = True break @@ -133,7 +131,7 @@ class ModelLoadBalancingService: else: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs[:]): - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -151,7 +149,7 @@ class ModelLoadBalancingService: provider=provider, model=model, model_type=model_type, - config_id=load_balancing_config.id + config_id=load_balancing_config.id, ) try: @@ -172,32 +170,32 @@ class ModelLoadBalancingService: if variable in credentials: try: 
credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa ) except ValueError: pass # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append({ - 'id': load_balancing_config.id, - 'name': load_balancing_config.name, - 'credentials': credentials, - 'enabled': load_balancing_config.enabled, - 'in_cooldown': in_cooldown, - 'ttl': ttl - }) + datas.append( + { + "id": load_balancing_config.id, + "name": load_balancing_config.name, + "credentials": credentials, + "enabled": load_balancing_config.enabled, + "in_cooldown": in_cooldown, + "ttl": ttl, + } + ) return is_load_balancing_enabled, datas - def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ - -> Optional[dict]: + def get_load_balancing_config( + self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str + ) -> Optional[dict]: """ Get load balancing configuration. 
:param tenant_id: workspace id @@ -219,14 +217,17 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: return None @@ -244,19 +245,19 @@ class ModelLoadBalancingService: # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) return { - 'id': load_balancing_model_config.id, - 'name': load_balancing_model_config.name, - 'credentials': credentials, - 'enabled': load_balancing_model_config.enabled + "id": load_balancing_model_config.id, + "name": load_balancing_model_config.name, + "credentials": credentials, + "enabled": load_balancing_model_config.enabled, } - def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ - -> LoadBalancingModelConfig: + def _init_inherit_config( + self, tenant_id: str, provider: str, model: str, model_type: ModelType + ) -> LoadBalancingModelConfig: """ Initialize the inherit configuration. 
:param tenant_id: workspace id @@ -271,18 +272,16 @@ class ModelLoadBalancingService: provider_name=provider, model_type=model_type.to_origin_model_type(), model_name=model, - name='__inherit__' + name="__inherit__", ) db.session.add(inherit_config) db.session.commit() return inherit_config - def update_load_balancing_configs(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - configs: list[dict]) -> None: + def update_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + ) -> None: """ Update load balancing configurations. :param tenant_id: workspace id @@ -304,15 +303,18 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) if not isinstance(configs, list): - raise ValueError('Invalid load balancing configs') + raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + current_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .all() + ) # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -320,25 +322,25 @@ class ModelLoadBalancingService: for config in configs: if not isinstance(config, dict): - raise ValueError('Invalid load balancing config') + raise ValueError("Invalid load balancing config") - config_id = config.get('id') - name = 
config.get('name') - credentials = config.get('credentials') - enabled = config.get('enabled') + config_id = config.get("id") + name = config.get("name") + credentials = config.get("credentials") + enabled = config.get("enabled") if not name: - raise ValueError('Invalid load balancing config name') + raise ValueError("Invalid load balancing config name") if enabled is None: - raise ValueError('Invalid load balancing config enabled') + raise ValueError("Invalid load balancing config enabled") # is config exists if config_id: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + raise ValueError("Invalid load balancing config id: {}".format(config_id)) updated_config_ids.add(config_id) @@ -347,11 +349,11 @@ class ModelLoadBalancingService: # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if credentials: if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -361,7 +363,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, - validate=False + validate=False, ) # update load balancing config @@ -375,19 +377,19 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == '__inherit__': - raise ValueError('Invalid load balancing config name') + if name == "__inherit__": + raise ValueError("Invalid load balancing config name") # 
check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if not credentials: - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -396,7 +398,7 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - validate=False + validate=False, ) # create load balancing config @@ -406,7 +408,7 @@ class ModelLoadBalancingService: model_type=model_type.to_origin_model_type(), model_name=model, name=name, - encrypted_config=json.dumps(credentials) + encrypted_config=json.dumps(credentials), ) db.session.add(load_balancing_model_config) @@ -420,12 +422,15 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) - def validate_load_balancing_credentials(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - credentials: dict, - config_id: Optional[str] = None) -> None: + def validate_load_balancing_credentials( + self, + tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None, + ) -> None: """ Validate load balancing credentials. 
:param tenant_id: workspace id @@ -450,14 +455,17 @@ class ModelLoadBalancingService: load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") @@ -469,16 +477,19 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - load_balancing_model_config=load_balancing_model_config + load_balancing_model_config=load_balancing_model_config, ) - def _custom_credentials_validate(self, tenant_id: str, - provider_configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, - validate: bool = True) -> dict: + def _custom_credentials_validate( + self, + tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True, + ) -> dict: """ Validate custom credentials. 
:param tenant_id: workspace id @@ -521,12 +532,11 @@ class ModelLoadBalancingService: provider=provider_configuration.provider.provider, model_type=model_type, model=model, - credentials=credentials + credentials=credentials, ) else: credentials = model_provider_factory.provider_credentials_validate( - provider=provider_configuration.provider.provider, - credentials=credentials + provider=provider_configuration.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -535,8 +545,9 @@ class ModelLoadBalancingService: return credentials - def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ - -> ModelCredentialSchema | ProviderCredentialSchema: + def _get_credential_schema( + self, provider_configuration: ProviderConfiguration + ) -> ModelCredentialSchema | ProviderCredentialSchema: """ Get form schemas. :param provider_configuration: provider configuration @@ -558,9 +569,7 @@ class ModelLoadBalancingService: :return: """ provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=config_id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 86c059fe90..c0f3c40762 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -73,8 +73,8 @@ class ModelProviderService: system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, current_quota_type=provider_configuration.system_configuration.current_quota_type, - quota_configurations=provider_configuration.system_configuration.quota_configurations - ) + quota_configurations=provider_configuration.system_configuration.quota_configurations, + ), ) 
provider_responses.append(provider_response) @@ -95,9 +95,9 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( - provider=provider - )] + return [ + ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ] def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: """ @@ -195,13 +195,12 @@ class ModelProviderService: # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - obfuscated=True + model_type=ModelType.value_of(model_type), model=model, obfuscated=True ) - def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def model_credentials_validate( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ validate model credentials. @@ -222,13 +221,12 @@ class ModelProviderService: # Validate model credentials provider_configuration.custom_model_credentials_validate( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def save_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ save model credentials. 
@@ -249,9 +247,7 @@ class ModelProviderService: # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: @@ -273,10 +269,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Remove custom model credentials - provider_configuration.delete_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model - ) + provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -290,9 +283,7 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) + models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider provider_models = {} @@ -323,16 +314,19 @@ class ModelProviderService: icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, - models=[ProviderModelWithStatusEntity( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - load_balancing_enabled=model.load_balancing_enabled - ) for model in models] + models=[ + ProviderModelWithStatusEntity( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + 
model_properties=model.model_properties, + status=model.status, + load_balancing_enabled=model.load_balancing_enabled, + ) + for model in models + ], ) ) @@ -361,19 +355,13 @@ class ModelProviderService: model_type_instance = cast(LargeLanguageModel, model_type_instance) # fetch credentials - credentials = provider_configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model - ) + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) if not credentials: return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules( - model=model, - credentials=credentials - ) + return model_type_instance.get_parameter_rules(model=model, credentials=credentials) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -384,22 +372,23 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type_enum - ) + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) try: - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types + return ( + DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types, + ), ) - ) if result else None + if result + else None + ) except Exception as e: 
logger.info(f"get_default_model_of_model_type error: {e}") return None @@ -416,13 +405,12 @@ class ModelProviderService: """ model_type_enum = ModelType.value_of(model_type) self.provider_manager.update_default_model_record( - tenant_id=tenant_id, - model_type=model_type_enum, - provider=provider, - model=model + tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) - def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: + def get_model_provider_icon( + self, provider: str, icon_type: str, lang: str + ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. @@ -434,11 +422,11 @@ class ModelProviderService: provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() - if icon_type.lower() == 'icon_small': + if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US @@ -446,13 +434,15 @@ class ModelProviderService: if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US root_path = current_app.root_path - provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) + provider_instance_path = os.path.dirname( + os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) + ) file_path = os.path.join(provider_instance_path, "_assets") file_path = os.path.join(file_path, file_name) @@ -460,10 +450,10 @@ class ModelProviderService: return None, None mimetype, _ = 
mimetypes.guess_type(file_path) - mimetype = mimetype or 'application/octet-stream' + mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: byte_data = f.read() return byte_data, mimetype @@ -509,10 +499,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.enable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -533,78 +520,49 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.disable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/apply' + api_url = api_base_url + "/api/v1/providers/apply" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") - if response.json()["code"] != 'success': - raise ValueError( - f"error: {response.json()['message']}" - ) + if response.json()["code"] != 
"success": + raise ValueError(f"error: {response.json()['message']}") rst = response.json() - if rst['type'] == 'redirect': - return { - 'type': rst['type'], - 'redirect_url': rst['redirect_url'] - } + if rst["type"] == "redirect": + return {"type": rst["type"], "redirect_url": rst["redirect_url"]} else: - return { - 'type': rst['type'], - 'result': 'success' - } + return {"type": rst["type"], "result": "success"} def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/qualification-verify' + api_url = api_base_url + "/api/v1/providers/qualification-verify" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - json_data = {'workspace_id': tenant_id, 'provider_name': provider} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + json_data = {"workspace_id": tenant_id, "provider_name": provider} if token: - json_data['token'] = token - response = requests.post(api_url, headers=headers, - json=json_data) + json_data["token"] = token + response = requests.post(api_url, headers=headers, json=json_data) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") rst = response.json() - if rst["code"] != 'success': - raise ValueError( - f"error: {rst['message']}" - ) + if rst["code"] != "success": + raise ValueError(f"error: {rst['message']}") - data = rst['data'] - if data['qualified'] is True: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': True - } + data = rst["data"] + if data["qualified"] is True: + return {"result": "success", "provider_name": provider, "flag": True} else: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': False, - 'reason': data['reason'] 
- } + return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d472f8cfbc..dfb21e767f 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -4,17 +4,18 @@ from models.model import App, AppModelConfig class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) if not app_model_config: raise ValueError("app model config not found") - name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['config'] + name = app_model_config.sensitive_word_avoidance_dict["type"] + config = app_model_config.sensitive_word_avoidance_dict["config"] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) return moderation.moderation_for_outputs(text) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 39f249dc24..8c8b64bcd5 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -4,15 +4,12 @@ import requests class OperationService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = 
{"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -22,11 +19,11 @@ class OperationService: @classmethod def record_utm(cls, tenant_id: str, utm_info: dict): params = { - 'tenant_id': tenant_id, - 'utm_source': utm_info.get('utm_source', ''), - 'utm_medium': utm_info.get('utm_medium', ''), - 'utm_campaign': utm_info.get('utm_campaign', ''), - 'utm_content': utm_info.get('utm_content', ''), - 'utm_term': utm_info.get('utm_term', '') + "tenant_id": tenant_id, + "utm_source": utm_info.get("utm_source", ""), + "utm_medium": utm_info.get("utm_medium", ""), + "utm_campaign": utm_info.get("utm_campaign", ""), + "utm_content": utm_info.get("utm_content", ""), + "utm_term": utm_info.get("utm_term", ""), } - return cls._send_request('POST', '/tenant_utms', params=params) + return cls._send_request("POST", "/tenant_utms", params=params) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 7b2edcf7cb..0650f2cb27 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -12,19 +12,25 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id - decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) - if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not 
decrypt_tracing_config.get('project_key')): + decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( + tenant_id, tracing_provider, trace_config_data.tracing_config + ) + if tracing_provider == "langfuse" and ( + "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") + ): project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - decrypt_tracing_config['project_key'] = project_key + decrypt_tracing_config["project_key"] = project_key decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) @@ -44,8 +50,10 @@ class OpsService: if tracing_provider not in provider_config_map.keys() and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['other_keys'] + config_class, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["other_keys"], + ) default_config_instance = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": @@ -59,9 +67,11 @@ class OpsService: project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) # check if trace config already exists - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if trace_config_data: return None @@ -69,8 +79,8 @@ class OpsService: # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = 
OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) - if tracing_provider == 'langfuse': - tracing_config['project_key'] = project_key + if tracing_provider == "langfuse": + tracing_config["project_key"] = project_key trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, @@ -94,9 +104,11 @@ class OpsService: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + current_trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not current_trace_config: return None @@ -126,9 +138,11 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config: return None diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17c..10abf0a764 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) class RecommendedAppService: - builtin_data: Optional[dict] = None @classmethod @@ -27,21 +26,21 @@ class RecommendedAppService: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, switch 
to built-in.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_apps_from_db(language) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_apps_from_builtin(language) else: - raise ValueError(f'invalid fetch recommended apps mode: {mode}') + raise ValueError(f"invalid fetch recommended apps mode: {mode}") - if not result.get('recommended_apps') and language != 'en-US': - result = cls._fetch_recommended_apps_from_builtin('en-US') + if not result.get("recommended_apps") and language != "en-US": + result = cls._fetch_recommended_apps_from_builtin("en-US") return result @@ -52,16 +51,18 @@ class RecommendedAppService: :param language: language :return: """ - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == language - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) if len(recommended_apps) == 0: - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == languages[0] - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) categories = set() recommended_apps_result = [] @@ -75,28 +76,28 @@ class RecommendedAppService: continue recommended_app_result = { - 'id': recommended_app.id, - 'app': { - 'id': app.id, - 'name': app.name, - 'mode': app.mode, - 'icon': app.icon, - 'icon_background': app.icon_background + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, }, - 'app_id': recommended_app.app_id, - 
'description': site.description, - 'copyright': site.copyright, - 'privacy_policy': site.privacy_policy, - 'custom_disclaimer': site.custom_disclaimer, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, } recommended_apps_result.append(recommended_app_result) categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -106,16 +107,16 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps?language={language}' + url = f"{domain}/apps?language={language}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: - raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") result = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) - + return result @classmethod @@ -126,7 +127,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('recommended_apps', {}).get(language) + return builtin_data.get("recommended_apps", {}).get(language) @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: @@ -136,18 +137,18 @@ class RecommendedAppService: :return: """ mode = 
dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_app_detail_from_dify_official(app_id) except Exception as e: - logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_app_detail_from_builtin(app_id) else: - raise ValueError(f'invalid fetch recommended app detail mode: {mode}') + raise ValueError(f"invalid fetch recommended app detail mode: {mode}") return result @@ -159,7 +160,7 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps/{app_id}' + url = f"{domain}/apps/{app_id}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None @@ -174,10 +175,11 @@ class RecommendedAppService: :return: """ # is in public recommended list - recommended_app = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.app_id == app_id - ).first() + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) if not recommended_app: return None @@ -188,12 +190,12 @@ class RecommendedAppService: return None return { - 'id': app_model.id, - 'name': app_model.name, - 'icon': app_model.icon, - 'icon_background': app_model.icon_background, - 'mode': app_model.mode, - 'export_data': AppDslService.export_dsl(app_model=app_model) + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": 
AppDslService.export_dsl(app_model=app_model), } @classmethod @@ -204,7 +206,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('app_details', {}).get(app_id) + return builtin_data.get("app_details", {}).get(app_id) @classmethod def _get_builtin_data(cls) -> dict: @@ -216,7 +218,7 @@ class RecommendedAppService: return cls.builtin_data root_path = current_app.root_path - with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: + with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: json_data = f.read() data = json.loads(json_data) cls.builtin_data = data @@ -229,27 +231,24 @@ class RecommendedAppService: Fetch all recommended apps and export datas :return: """ - templates = { - "recommended_apps": {}, - "app_details": {} - } + templates = {"recommended_apps": {}, "app_details": {}} for language in languages: try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") continue - templates['recommended_apps'][language] = result + templates["recommended_apps"][language] = result - for recommended_app in result.get('recommended_apps'): - app_id = recommended_app.get('app_id') + for recommended_app in result.get("recommended_apps"): + app_id = recommended_app.get("app_id") # get app detail app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) if not app_detail: continue - templates['app_details'][app_id] = app_detail + templates["app_details"][app_id] = app_detail return templates diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index f1113c1505..9fe3cecce7 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -10,46 +10,48 @@ from 
services.message_service import MessageService class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int) -> InfiniteScrollPagination: - saved_messages = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).order_by(SavedMessage.created_at.desc()).all() + def pagination_by_last_id( + cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + ) -> InfiniteScrollPagination: + saved_messages = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .order_by(SavedMessage.created_at.desc()) + .all() + ) message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( - app_model=app_model, - user=user, - last_id=last_id, - limit=limit, - include_ids=message_ids + app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if saved_message: return - message = 
MessageService.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(saved_message) @@ -57,12 +59,16 @@ class SavedMessageService: @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if not saved_message: return diff --git a/api/services/tag_service.py b/api/services/tag_service.py index d6eba38fbd..0c17485a9f 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: - query = db.session.query( - Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') - ).outerjoin( - TagBinding, Tag.id == TagBinding.tag_id - ).filter( - Tag.type == tag_type, - Tag.tenant_id == current_tenant_id + query = ( + db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) + 
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) - query = query.group_by( - Tag.id - ) + query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.group_by(Tag.id) results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: - tags = db.session.query(Tag).filter( - Tag.id.in_(tag_ids), - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) if not tags: return [] tag_ids = [tag.id for tag in tags] - tag_bindings = db.session.query( - TagBinding.target_id - ).filter( - TagBinding.tag_id.in_(tag_ids), - TagBinding.tenant_id == current_tenant_id - ).all() + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) if not tag_bindings: return [] results = [tag_binding.target_id for tag_binding in tag_bindings] @@ -51,27 +45,28 @@ class TagService: @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == target_id, - TagBinding.tenant_id == current_tenant_id, - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) return tags if tags else [] - @staticmethod def save_tags(args: dict) -> Tag: tag = Tag( id=str(uuid.uuid4()), - name=args['name'], - type=args['type'], + 
name=args["name"], + type=args["type"], created_by=current_user.id, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) db.session.add(tag) db.session.commit() @@ -82,7 +77,7 @@ class TagService: tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") - tag.name = args['name'] + tag.name = args["name"] db.session.commit() return tag @@ -107,20 +102,21 @@ class TagService: @staticmethod def save_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding - for tag_id in args['tag_ids']: - tag_binding = db.session.query(TagBinding).filter( - TagBinding.tag_id == tag_id, - TagBinding.target_id == args['target_id'] - ).first() + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args['target_id'], + target_id=args["target_id"], tenant_id=current_user.current_tenant_id, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(new_tag_binding) db.session.commit() @@ -128,34 +124,34 @@ class TagService: @staticmethod def delete_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter( - TagBinding.target_id == args['target_id'], - TagBinding.tag_id == (args['tag_id']) - ).first() + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) if tag_bindings: db.session.delete(tag_bindings) db.session.commit() - - @staticmethod def 
check_target_exists(type: str, target_id: str): - if type == 'knowledge': - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == current_user.current_tenant_id, - Dataset.id == target_id - ).first() + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) if not dataset: raise NotFound("Dataset not found") - elif type == 'app': - app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.id == target_id - ).first() + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) if not app: raise NotFound("App not found") else: raise NotFound("Invalid binding type") - diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ecc065d521..3ded9c0989 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -29,111 +29,107 @@ class ApiToolManageService: @staticmethod def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ - parse api schema to tool bundle + parse api schema to tool bundle """ try: warnings = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') - + raise ValueError(f"invalid schema: {str(e)}") + credentials_schema = [ ToolProviderCredentials( - name='auth_type', + name="auth_type", type=ToolProviderCredentials.CredentialsType.SELECT, required=True, - default='none', + default="none", options=[ - ToolCredentialsOption(value='none', label=I18nObject( - en_US='None', - zh_Hans='无' - )), - ToolCredentialsOption(value='api_key', label=I18nObject( - en_US='Api Key', - zh_Hans='Api Key' - )), + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", 
zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], - placeholder=I18nObject( - en_US='Select auth type', - zh_Hans='选择认证方式' - ) + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), ToolProviderCredentials( - name='api_key_header', + name="api_key_header", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key header', - zh_Hans='输入 api key header,如:X-API-KEY' - ), - default='api_key', - help=I18nObject( - en_US='HTTP header name for api key', - zh_Hans='HTTP 头部字段名,用于传递 api key' - ) + placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), ToolProviderCredentials( - name='api_key_value', + name="api_key_value", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key', - zh_Hans='输入 api key' - ), - default='' + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", ), ] - return jsonable_encoder({ - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - }) + return jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ - convert schema to tool bundles + convert schema to tool bundles - :return: the list of tool bundles, description + :return: the list of tool bundles, description """ try: tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, 
extra_info=extra_info) return tool_bundles except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def create_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - create api tool provider + create api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is not None: - raise ValueError(f'provider {provider_name} already exists') + raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + if len(tool_bundles) > 100: - raise ValueError('the number of apis should be less than 100') + raise ValueError("the number of apis should be less than 100") # create db provider db_provider = ApiToolProvider( @@ -142,19 +138,19 @@ class ApiToolManageService: name=provider_name, icon=json.dumps(icon), schema=schema, - description=extra_info.get('description', ''), + description=extra_info.get("description", ""), 
schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer + custom_disclaimer=custom_disclaimer, ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -172,14 +168,12 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider_remote_schema( - user_id: str, tenant_id: str, url: str - ): + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): """ - get api tool provider remote schema + get api tool provider remote schema """ headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", @@ -189,84 +183,98 @@ class ApiToolManageService: try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: - raise ValueError(f'Got status code {response.status_code}') + raise ValueError(f"Got status code {response.status_code}") schema = response.text # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") - raise ValueError('invalid schema, please check the url you provided') - - return { - 'schema': schema - } + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} @staticmethod - def list_api_tool_provider_tools( - user_id: 
str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list api tool provider tools + list api tool provider tools """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') - + raise ValueError(f"you have not added provider {provider}") + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - + return [ ToolTransformService.tool_to_user_tool( tool_bundle, labels=labels, - ) for tool_bundle in provider.tools + ) + for tool_bundle in provider.tools ] @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - update api tool provider + update api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + 
.filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) if provider is None: - raise ValueError(f'api provider {provider_name} does not exists') + raise ValueError(f"api provider {provider_name} does not exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + # update db provider provider.name = provider_name provider.icon = json.dumps(icon) provider.schema = schema - provider.description = extra_info.get('description', '') + provider.description = extra_info.get("description", "") provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(provider, auth_type) @@ -295,84 +303,91 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def delete_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + 
db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider( - user_id: str, tenant_id: str, provider: str - ): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - get api tool provider + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) - + @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, - schema: str + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, ): """ - test api tool before adding api tool provider + test api tool before adding api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema_type}') - + raise ValueError(f"invalid schema type {schema_type}") + try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception as e: - raise ValueError('invalid schema') - + raise ValueError("invalid schema") + # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: - raise ValueError(f'invalid tool name {tool_name}') - - db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + raise ValueError(f"invalid tool name {tool_name}") + + db_provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + 
ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if not db_provider: # create a fake db provider db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', + tenant_id="", + user_id="", + name="", + icon="", schema=schema, - description='', + description="", schema_type_str=ApiProviderSchemaType.OPENAPI.value, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -381,10 +396,7 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, - provider_controller=provider_controller - ) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -396,27 +408,27 @@ class ApiToolManageService: provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } - - return { 'result': result or 
'empty response' } - + return {"error": str(e)} + + return {"result": result or "empty response"} + @staticmethod - def list_api_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list api tools + list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) result: list[UserToolProvider] = [] @@ -425,26 +437,21 @@ class ApiToolManageService: provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller, - db_provider=provider, - decrypt_credentials=True + provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools( - user_id=user_id, tenant_id=tenant_id - ) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) for tool in tools: - user_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_provider.original_credentials, - labels=labels - )) + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ebadbd9be0..dc8cebb587 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py 
@@ -20,21 +20,25 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list builtin tool provider tools + list builtin tool provider tools """ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) credentials = {} if builtin_provider is not None: @@ -44,47 +48,47 @@ class BuiltinToolManageService: result = [] for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, - tenant_id=tenant_id, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) return result @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): + def list_builtin_provider_credentials_schema(provider_name): """ - list builtin provider credentials schema + list builtin provider credentials schema - :return: the list of tool 
providers + :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): + def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): """ - update builtin tool provider + update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') + raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: @@ -121,19 +125,21 @@ class BuiltinToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return {'result': 'success'} + return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): + def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): """ - get builtin tool provider credentials + get builtin tool provider credentials """ - provider: BuiltinToolProvider = 
db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) if provider is None: return {} @@ -145,19 +151,21 @@ class BuiltinToolManageService: return credentials @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') + raise ValueError(f"you have not added provider {provider_name}") db.session.delete(provider) db.session.commit() @@ -167,38 +175,36 @@ class BuiltinToolManageService: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return {'result': 'success'} + return {"result": "success"} @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): + def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: + with open(icon_path, "rb") as f: icon_bytes = f.read() return icon_bytes, mime_type @staticmethod - def 
list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list builtin tools + list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers() # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider, db_providers), None + ) result: list[UserToolProvider] = [] @@ -209,7 +215,7 @@ class BuiltinToolManageService: include_set=dify_config.POSITION_TOOL_INCLUDES_SET, exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider_controller, - name_func=lambda x: x.identity.name + name_func=lambda x: x.identity.name, ): continue @@ -217,7 +223,7 @@ class BuiltinToolManageService: user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True + decrypt_credentials=True, ) # add icon @@ -225,12 +231,14 @@ class BuiltinToolManageService: tools = provider_controller.get_tools() for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + 
credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) result.append(user_builtin_provider) except Exception as e: diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py index 8a6aa025f2..35e58b5ade 100644 --- a/api/services/tools/tool_labels_service.py +++ b/api/services/tools/tool_labels_service.py @@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels class ToolLabelsService: @classmethod def list_tool_labels(cls) -> list[ToolLabel]: - return default_tool_labels \ No newline at end of file + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 76d2f53ae8..1c67f7648c 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -11,13 +11,11 @@ class ToolCommonService: @staticmethod def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): """ - list tool providers + list tool providers - :return: the list of tool providers + :return: the list of tool providers """ - providers = ToolManager.user_list_providers( - user_id, tenant_id, typ - ) + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) # add icon for provider in providers: @@ -26,4 +24,3 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] return result - \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index cfce3fbd01..6fb0f2f517 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi logger = logging.getLogger(__name__) + class ToolTransformService: @staticmethod def get_tool_provider_icon_url(provider_type: str, 
provider_name: str, icon: str) -> Union[str, dict]: """ - get tool provider icon url + get tool provider icon url """ - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/") - + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + if provider_type == ToolProviderType.BUILT_IN.value: - return url_prefix + 'builtin/' + provider_name + '/icon' + return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - - return '' - + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + @staticmethod def repack_provider(provider: Union[dict, UserToolProvider]): """ - repack provider + repack provider - :param provider: the provider dict + :param provider: the provider dict """ - if isinstance(provider, dict) and 'icon' in provider: - provider['icon'] = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider['type'], - provider_name=provider['name'], - icon=provider['icon'] + if isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.icon + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @staticmethod @@ -92,14 +85,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=False, tools=[], - labels=provider_controller.tool_labels + labels=provider_controller.tool_labels, ) # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in 
schema.items(): - result.masked_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: @@ -113,8 +105,7 @@ class ToolTransformService: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -124,7 +115,7 @@ class ToolTransformService: result.original_credentials = decrypted_credentials return result - + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -135,25 +126,23 @@ class ToolTransformService: # package tool provider controller controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, ) return controller - + @staticmethod - def workflow_provider_to_controller( - db_provider: WorkflowToolProvider - ) -> WorkflowToolProviderController: + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: """ convert provider controller to provider """ return WorkflowToolProviderController.from_db(db_provider) - + @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, - labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: list[str] = None ): """ convert provider controller to user provider @@ -175,7 +164,7 @@ class 
ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) @staticmethod @@ -183,16 +172,16 @@ class ToolTransformService: provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None + labels: list[str] = None, ) -> UserToolProvider: """ convert provider controller to user provider """ - username = 'Anonymous' + username = "Anonymous" try: username = db_provider.user.name except Exception as e: - logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = UserToolProvider( @@ -212,14 +201,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials @@ -229,23 +217,25 @@ class ToolTransformService: result.masked_credentials = masked_credentials return result - + @staticmethod def tool_to_user_tool( - tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, tenant_id: str = None, - labels: list[str] = None + labels: list[str] = None, ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) # get tool parameters 
parameters = tool.parameters or [] @@ -270,20 +260,14 @@ class ToolTransformService: label=tool.identity.label, description=tool.description.human, parameters=current_parameters, - labels=labels + labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, - label=I18nObject( - en_US=tool.operation_id, - zh_Hans=tool.operation_id - ), - description=I18nObject( - en_US=tool.summary or '', - zh_Hans=tool.summary or '' - ), + label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, - labels=labels - ) \ No newline at end of file + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 185483a71c..3830e75339 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -19,10 +19,21 @@ class WorkflowToolManageService: """ Service class for managing workflow tools. """ + @classmethod - def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, - label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def create_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Create a workflow tool. 
:param user_id: the user id @@ -38,27 +49,28 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') - - app: App = db.session.query(App).filter( - App.id == workflow_app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + + app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: - raise ValueError(f'App {workflow_app_id} not found') - + raise ValueError(f"App {workflow_app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_app_id}') - + raise ValueError(f"Workflow not found for app {workflow_app_id}") + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -76,19 +88,26 @@ class WorkflowToolManageService: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - + db.session.add(workflow_tool_provider) db.session.commit() - return { - 'result': 'success' - } - + return {"result": "success"} @classmethod - def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, - name: str, label: str, icon: dict, 
description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Update a workflow tool. :param user_id: the user id @@ -106,35 +125,39 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} already exists') - - workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if workflow_tool_provider is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - app: App = db.session.query(App).filter( - App.id == workflow_tool_provider.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: App = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == 
tenant_id).first() + ) if app is None: - raise ValueError(f'App {workflow_tool_provider.app_id} not found') - + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') - + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -154,13 +177,10 @@ class WorkflowToolManageService: if labels is not None: ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), - labels + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: @@ -170,9 +190,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id - ).all() + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() tools = [] for provider in db_tools: @@ -188,14 +206,12 @@ class WorkflowToolManageService: for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=tool, - labels=labels.get(tool.provider_id, []) + provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=labels.get(tool.provider_id, []) + tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) ) ] result.append(user_tool_provider) @@ -211,15 
+227,12 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id """ db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() db.session.commit() - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: @@ -230,40 +243,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - 
labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy, + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: """ @@ -273,40 +283,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == workflow_app_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_app_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': 
json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: """ @@ -316,19 +323,19 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') + raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) - ] \ No newline at end of file + ] diff --git 
a/api/services/vector_service.py b/api/services/vector_service.py index 232d294325..3c67351335 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment class VectorService: - @classmethod - def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], - segments: list[DocumentSegment], dataset: Dataset): + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + ): documents = [] for segment in segments: document = Document( @@ -20,14 +20,12 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # save vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.add_texts(documents, duplicate_check=True) # save keyword index @@ -50,13 +48,11 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # update vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 269a048134..d7ccc964cb 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,18 +11,29 @@ from services.conversation_service import ConversationService class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, 
invoke_from: InvokeFrom, - pinned: Optional[bool] = None, - sort_by='-updated_at') -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: - pinned_conversations = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversations = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + .all() + ) pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: include_ids = pinned_conversation_ids @@ -37,32 +48,34 @@ class WebConversationService: invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, - sort_by=sort_by + sort_by=sort_by, ) @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + 
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if pinned_conversation: return conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=conversation_id, - user=user + app_model=app_model, conversation_id=conversation_id, user=user ) pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(pinned_conversation) @@ -70,12 +83,16 @@ class WebConversationService: @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if not pinned_conversation: return diff --git a/api/services/website_service.py b/api/services/website_service.py index c166b01237..6dff35d63f 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService class WebsiteService: - @classmethod def document_create_args_validate(cls, args: dict): - if 'url' not in args or not args['url']: - raise ValueError('url is 
required') - if 'options' not in args or not args['options']: - raise ValueError('options is required') - if 'limit' not in args['options'] or not args['options']['limit']: - raise ValueError('limit is required') + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get('provider') - url = args.get('url') - options = args.get('options') - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + provider = args.get("provider") + url = args.get("url") + options = args.get("options") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - crawl_sub_pages = options.get('crawl_sub_pages', False) - only_main_content = options.get('only_main_content', False) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) if not crawl_sub_pages: params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "limit": 1, - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "returnOnlyUrls": False, + 
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } else: - includes = options.get('includes').split(',') if options.get('includes') else [] - excludes = options.get('excludes').split(',') if options.get('excludes') else [] + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": includes if includes else [], "excludes": excludes if excludes else [], "generateImgAltText": True, - "limit": options.get('limit', 1), - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "limit": options.get("limit", 1), + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } - if options.get('max_depth'): - params['crawlerOptions']['maxDepth'] = options.get('max_depth') + if options.get("max_depth"): + params["crawlerOptions"]["maxDepth"] = options.get("max_depth") job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f'website_crawl_{job_id}' + website_crawl_time_cache_key = f"website_crawl_{job_id}" time = str(datetime.datetime.now().timestamp()) redis_client.setex(website_crawl_time_cache_key, 3600, time) - return { - 'status': 'active', - 'job_id': job_id - } + return {"status": "active", "job_id": job_id} else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - 
token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) crawl_status_data = { - 'status': result.get('status', 'active'), - 'job_id': job_id, - 'total': result.get('total', 0), - 'current': result.get('current', 0), - 'data': result.get('data', []) + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), } - if crawl_status_data['status'] == 'completed': - website_crawl_time_cache_key = f'website_crawl_{job_id}' + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" start_time = redis_client.get(website_crawl_time_cache_key) if start_time: end_time = datetime.datetime.now().timestamp() time_consuming = abs(end_time - float(start_time)) - crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" redis_client.delete(website_crawl_time_cache_key) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': - file_key = 'website_files/' + job_id + '.txt' + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): data = storage.load_once(file_key) if data: - data = 
json.loads(data.decode('utf-8')) + data = json.loads(data.decode("utf-8")) else: # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) - if result.get('status') != 'completed': - raise ValueError('Crawl job is not completed') - data = result.get('data') + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") if data: for item in data: - if item.get('source_url') == url: + if item.get("source_url") == url: return item return None else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - params = { - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } - } + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"pageOptions": {"onlyMainContent": 
only_main_content, "includeHtml": False}} result = firecrawl_app.scrape_url(url, params) return result else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index c4d3d27631..b4f0882a3a 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus class WorkflowAppService: - def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: """ Get paginate workflow app logs @@ -18,20 +17,14 @@ class WorkflowAppService: :param args: request args :return: """ - query = ( - db.select(WorkflowAppLog) - .where( - WorkflowAppLog.tenant_id == app_model.tenant_id, - WorkflowAppLog.app_id == app_model.id - ) + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None - keyword = args['keyword'] + status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + keyword = args["keyword"] if keyword or status: - query = query.join( - WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id - ) + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) if keyword: keyword_like_val = f"%{args['keyword'][:30]}%" @@ -39,7 +32,7 @@ class WorkflowAppService: WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val)) + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] # filter keyword by workflow run id @@ -49,23 +42,16 @@ class WorkflowAppService: query = query.outerjoin( 
EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), ).filter(or_(*keyword_conditions)) if status: # join with workflow_run and filter by status - query = query.filter( - WorkflowRun.status == status.value - ) + query = query.filter(WorkflowRun.status == status.value) query = query.order_by(WorkflowAppLog.created_at.desc()) - pagination = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return pagination diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index ccce38ada0..b7b3abeaa2 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -18,6 +18,7 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ + class WorkflowWithMessage: message_id: str conversation_id: str @@ -33,9 +34,7 @@ class WorkflowRunService: with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message - with_message_workflow_run = WorkflowWithMessage( - workflow_run=workflow_run - ) + with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id @@ -53,26 +52,30 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ - limit = int(args.get('limit', 20)) + limit = int(args.get("limit", 20)) base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if 
args.get('last_id'): + if args.get("last_id"): last_workflow_run = base_query.filter( - WorkflowRun.id == args.get('last_id'), + WorkflowRun.id == args.get("last_id"), ).first() if not last_workflow_run: - raise ValueError('Last workflow run not exists') + raise ValueError("Last workflow run not exists") - workflow_runs = base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, - WorkflowRun.id != last_workflow_run.id - ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() @@ -81,17 +84,13 @@ class WorkflowRunService: current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id + WorkflowRun.id != current_page_first_workflow_run.id, ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=workflow_runs, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: """ @@ -100,11 +99,15 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ).first() + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) return workflow_run @@ -117,12 +120,17 @@ class WorkflowRunService: if not workflow_run: 
return [] - node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ).order_by(WorkflowNodeExecution.index.desc()).all() + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f016f2937c..9c36193328 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -39,11 +39,13 @@ class WorkflowService: Get draft workflow """ # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) # return draft workflow return workflow @@ -57,11 +59,15 @@ class WorkflowService: return None # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + 
Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) return workflow @@ -87,10 +93,7 @@ class WorkflowService: raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure( - app_model=app_model, - features=features - ) + self.validate_features_structure(app_model=app_model, features=features) # create draft workflow if not found if not workflow: @@ -98,7 +101,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, @@ -124,9 +127,7 @@ class WorkflowService: # return draft workflow return workflow - def publish_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: """ Publish workflow from draft @@ -139,7 +140,7 @@ class WorkflowService: draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('No valid workflow found.') + raise ValueError("No valid workflow found.") # create new workflow workflow = Workflow( @@ -201,17 +202,16 @@ class WorkflowService: return default_config - def run_draft_workflow_node(self, app_model: App, - node_id: str, - user_inputs: dict, - account: Account) -> WorkflowNodeExecution: + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: """ Run draft workflow node """ # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") # run draft workflow node start_at = time.perf_counter() @@ -235,7 +235,7 @@ class WorkflowService: 
if not node_run_result: raise ValueError('Node run failed with no run result') - + run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: @@ -290,16 +290,16 @@ class WorkflowService: workflow_converter = WorkflowConverter() if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: - raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get('name'), - icon_type=args.get('icon_type'), - icon=args.get('icon'), - icon_background=args.get('icon_background'), + name=args.get("name"), + icon_type=args.get("icon_type"), + icon=args.get("icon"), + icon_background=args.get("icon_background"), ) return new_app @@ -307,15 +307,11 @@ class WorkflowService: def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 2bcbe5c6f6..8fcb12b1cb 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ - from flask_login import current_user from 
configs import dify_config @@ -14,34 +13,40 @@ class WorkspaceService: if not tenant: return None tenant_info = { - 'id': tenant.id, - 'name': tenant.name, - 'plan': tenant.plan, - 'status': tenant.status, - 'created_at': tenant.created_at, - 'in_trail': True, - 'trial_end_reason': None, - 'role': 'normal', + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", } # Get role of user - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == current_user.id - ).first() - tenant_info['role'] = tenant_account_join.role + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, - [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): base_url = dify_config.FILES_URL - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) - tenant_info['custom_config'] = { - 'remove_webapp_brand': remove_webapp_brand, - 
'replace_webapp_logo': replace_webapp_logo, + tenant_info["custom_config"] = { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e0a1b21909..b50876cc79 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import Document as DatasetDocument from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index @@ -22,24 +22,25 @@ def add_document_to_index_task(dataset_document_id: str): Usage: add_document_to_index.delay(document_id) """ - logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) + logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() if not dataset_document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if dataset_document.indexing_status != 'completed': + if dataset_document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) + indexing_cache_key = "document_{}_indexing".format(dataset_document.id) try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) \ - .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) documents = [] for segment in segments: @@ -50,7 +51,7 @@ def 
add_document_to_index_task(dataset_document_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -58,7 +59,7 @@ def add_document_to_index_task(dataset_document_id: str): dataset = dataset_document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -66,12 +67,15 @@ def add_document_to_index_task(dataset_document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) + click.style( + "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" + ) + ) except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - dataset_document.status = 'error' + dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() finally: diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index b3aa8b596c..25c55bcfaf 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -10,9 +10,10 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def add_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Add annotation to index. 
:param annotation_id: annotation id @@ -23,38 +24,34 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create([document], duplicate_check=True) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6e6b16045d..fa7e5ac919 100644 
--- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -14,9 +14,8 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, - user_id: str): +@shared_task(queue="dataset") +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str): """ Add annotation to index. :param job_id: job_id @@ -26,72 +25,66 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: :param user_id: user_id """ - logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: documents = [] for content in content_list: annotation = MessageAnnotation( - app_id=app.id, - content=content['answer'], - question=content['question'], - account_id=user_id + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id ) db.session.add(annotation) db.session.flush() document = Document( - page_content=content['question'], - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) 
documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, - 'annotation' + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) ) if not dataset_collection_binding: raise NotFound("App annotation setting not found") dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create(documents, duplicate_check=True) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), - fg='green')) + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", + ) + ) except Exception as e: db.session.rollback() - redis_client.setex(indexing_cache_key, 600, 'error') - indexing_error_msg_key = 
'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) redis_client.setex(indexing_error_msg_key, 600, str(e)) logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 81155a35e4..5758db53de 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -9,36 +9,33 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str): """ Async delete annotation index task """ - logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=dataset_collection_binding.id + indexing_technique="high_quality", + collection_binding_id=dataset_collection_binding.id, ) try: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", 
annotation_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation deleted index failed:{}".format(str(e))) - diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index e5e03c9b51..0f83dfdbd4 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -12,49 +12,44 @@ from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): """ Async enable annotation reply task """ - logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() if not app: raise NotFound("App not found") - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if not app_annotation_setting: raise NotFound("App 
annotation setting not found") - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) try: - dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=app_annotation_setting.collection_binding_id + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, ) try: if annotations_count > 0: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('app_id', app_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("app_id", app_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, 'completed') + redis_client.setex(disable_app_annotation_job_key, 600, "completed") # delete annotation setting db.session.delete(app_annotation_setting) @@ -62,12 +57,12 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch deleted index failed:{}".format(str(e))) - redis_client.setex(disable_app_annotation_job_key, 600, 'error') - disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = 
"disable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(disable_app_annotation_error_key, 600, str(e)) finally: redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index fda8b7a250..82b70f6b71 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -15,37 +15,39 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, - embedding_provider_name: str, embedding_model_name: str): +@shared_task(queue="dataset") +def enable_annotation_reply_task( + job_id: str, + app_id: str, + user_id: str, + tenant_id: str, + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, +): """ Async enable annotation reply task """ - logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: raise NotFound("App not found") annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_job_key = 
"enable_app_annotation_job_{}".format(str(job_id)) try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" + ) + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() if annotation_setting: annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id @@ -58,48 +60,42 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ score_threshold=score_threshold, collection_binding_id=dataset_collection_binding.id, created_user_id=user_id, - updated_user_id=user_id + updated_user_id=user_id, ) db.session.add(new_app_annotation_setting) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) try: - vector.delete_by_metadata_field('app_id', app_id) + vector.delete_by_metadata_field("app_id", app_id) except Exception as e: - logging.info( - click.style('Delete annotation index error: 
{}'.format(str(e)), - fg='red')) + logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) vector.create(documents) db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, 'completed') + redis_client.setex(enable_app_annotation_job_key, 600, "completed") end_at = time.perf_counter() logging.info( - click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch created index failed:{}".format(str(e))) - redis_client.setex(enable_app_annotation_job_key, 600, 'error') - enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(enable_app_annotation_error_key, 600, str(e)) db.session.rollback() finally: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 7219abd3cd..b685d84d07 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -10,9 +10,10 @@ from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def update_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Update annotation to index. 
:param annotation_id: annotation id @@ -23,39 +24,35 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git 
a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 67cc03bdeb..de7f0ddec1 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -16,9 +16,10 @@ from libs import helper from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') -def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: str, document_id: str, - tenant_id: str, user_id: str): +@shared_task(queue="dataset") +def batch_create_segment_to_index_task( + job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str +): """ Async batch create segment to index :param job_id: @@ -30,44 +31,44 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s Usage: batch_create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset not exist.') + raise ValueError("Dataset not exist.") dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - raise ValueError('Document not exist.') + raise ValueError("Document not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - raise ValueError('Document is not available.') + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + raise ValueError("Document is not available.") document_segments = [] embedding_model 
= None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) for segment in content: - content = segment['content'] + content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) if embedding_model else 0 - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == dataset_document.id - ).scalar() + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=tenant_id, dataset_id=dataset_id, @@ -80,20 +81,22 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s tokens=tokens, created_by=user_id, indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - status='completed', - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + status="completed", + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) - if dataset_document.doc_form == 'qa_model': - segment_document.answer = segment['answer'] + if dataset_document.doc_form == "qa_model": + segment_document.answer = segment["answer"] db.session.add(segment_document) document_segments.append(segment_document) # add index to db indexing_runner = IndexingRunner() indexing_runner.batch_add_segments(document_segments, dataset) db.session.commit() - 
redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() - logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Segments batch created index failed:{}".format(str(e))) - redis_client.setex(indexing_cache_key, 600, 'error') + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 1f26c966c4..3624903801 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -19,9 +19,15 @@ from models.model import UploadFile # Add import statement for ValueError -@shared_task(queue='dataset') -def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str, doc_form: str): +@shared_task(queue="dataset") +def clean_dataset_task( + dataset_id: str, + tenant_id: str, + indexing_technique: str, + index_struct: str, + collection_binding_id: str, + doc_form: str, +): """ Clean dataset when dataset deleted. 
:param dataset_id: dataset id @@ -33,7 +39,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start clean dataset when dataset deleted: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: @@ -48,9 +54,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: - logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) else: - logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green")) # Specify the index type before initializing the index processor if doc_form is None: raise ValueError("Index type must be specified.") @@ -71,15 +77,16 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if documents: for document in documents: try: - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == 
document.tenant_id, UploadFile.id == file_id) + .first() + ) if not file: continue storage.delete(file.key) @@ -90,6 +97,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned dataset when dataset deleted: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style( + "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Cleaned dataset when dataset deleted failed") diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 0fd05615b6..ae2855aa2e 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -12,7 +12,7 @@ from models.dataset import Dataset, DocumentSegment from models.model import UploadFile -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): """ Clean document when document deleted. 
@@ -23,14 +23,14 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i Usage: clean_document_task.delay(document_id, dataset_id) """ - logging.info(click.style('Start clean document when document deleted: {}'.format(document_id), fg='green')) + logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() # check segment is exist @@ -44,9 +44,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter( - UploadFile.id == file_id - ).first() + file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -57,6 +55,10 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 9b697b6351..75d9e03130 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -9,7 +9,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_notion_document_task(document_ids: 
list[str], dataset_id: str): """ Clean document when document deleted. @@ -18,20 +18,20 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): Usage: clean_notion_document_task.delay(document_ids, dataset_id) """ - logging.info(click.style('Start clean document when import form notion document deleted: {}'.format(dataset_id), fg='green')) + logging.info( + click.style("Start clean document when import from notion document deleted: {}".format(dataset_id), fg="green") + ) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() db.session.delete(document) segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() @@ -44,8 +44,12 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Clean document when import form notion document deleted end :: {} latency: {}'.format( - dataset_id, end_at - start_at), - fg='green')) + click.style( + "Clean document when import from notion document deleted end :: {} latency: {}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when import form notion document deleted failed") diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index d31286e4cc..26375743b6 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -14,7 +14,7 @@ from 
extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): """ Async create segment to index @@ -22,23 +22,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] :param keywords: Usage: create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'waiting': + if segment.status != "waiting": return - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -49,23 +49,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no 
document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset.doc_form @@ -75,18 +75,20 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() - logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("create segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index ce93e111e5..cfc54920e2 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Dataset, DocumentSegment from 
models.dataset import Document as DatasetDocument -@shared_task(queue='dataset') +@shared_task(queue="dataset") def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index @@ -19,41 +19,46 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): :param action: action Usage: deal_dataset_vector_index_task.delay(dataset_id, action) """ - logging.info(click.style('Start deal dataset vector index: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() for dataset_document in 
dataset_documents: try: # add from vector index - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -64,32 +69,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - elif action == 'update': - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + 
DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() # clean index @@ -98,10 +110,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ).order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -112,23 +126,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": 
str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - end_at = time.perf_counter() logging.info( - click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("Deal dataset vector index failed") diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index d79286cf3d..c3e0ea5d9f 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -10,7 +10,7 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): """ Async Remove segment from index @@ -21,22 +21,22 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ Usage: delete_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = 
db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment_id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form @@ -44,7 +44,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ index_processor.clean(dataset, [index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("delete segment from index failed") finally: diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 4788bf4e4b..15e1e50076 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -11,7 +11,7 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_segment_from_index_task(segment_id: str): """ Async disable segment from index @@ -19,33 +19,33 @@ def disable_segment_from_index_task(segment_id: str): Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start disable segment from index: 
{}'.format(segment_id), fg='green')) + logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed , disable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, disable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset_document.doc_form @@ -53,7 +53,9 @@ def disable_segment_from_index_task(segment_id: str): index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + 
click.style("Segment removed from index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception: logging.exception("remove segment from index failed") segment.enabled = True diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4cced36ecd..9ea4c99649 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -14,7 +14,7 @@ from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceOauthBinding -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_sync_task(dataset_id: str, document_id: str): """ Async update document @@ -23,50 +23,50 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): Usage: document_indexing_sync_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start sync document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") data_source_info = document.data_source_info_dict - if document.data_source_type == 'notion_import': - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): raise ValueError("no notion page found") - workspace_id = data_source_info['notion_workspace_id'] - page_id = data_source_info['notion_page_id'] - page_type = 
data_source_info['type'] - page_edited_time = data_source_info['last_edited_time'] + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') + raise ValueError("Data source binding not found.") loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=document.tenant_id + tenant_id=document.tenant_id, ) last_edited_time = loader.get_notion_last_edited_time() # check the page is updated if last_edited_time != page_edited_time: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -74,7 +74,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -89,7 +89,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: 
{}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -97,8 +103,10 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index cc93a1341e..e0da5f9ed0 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -12,7 +12,7 @@ from models.dataset import Dataset, Document from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -36,16 +36,17 @@ def document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) @@ -53,15 +54,14 @@ def document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -71,8 +71,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py 
b/api/tasks/document_indexing_update_task.py index f129d93de8..6e681bcf4f 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -12,7 +12,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_update_task(dataset_id: str, document_id: str): """ Async update document @@ -21,18 +21,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str): Usage: document_indexing_update_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start update document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -40,7 +37,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -57,7 +54,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: 
{}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -65,8 +68,8 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 884e222d1b..0a7568c385 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def duplicate_document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -37,16 +37,17 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -54,12 +55,11 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: # clean old data @@ -77,7 +77,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() documents.append(document) db.session.add(document) @@ -87,8 +87,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: 
pass diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index e37c06855d..1412ad9ec7 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -13,7 +13,7 @@ from extensions.ext_redis import redis_client from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def enable_segment_to_index_task(segment_id: str): """ Async enable segment to index @@ -21,17 +21,17 @@ def enable_segment_to_index_task(segment_id: str): Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: document = Document( @@ -41,23 +41,23 @@ def enable_segment_to_index_task(segment_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + 
logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() @@ -65,12 +65,14 @@ def enable_segment_to_index_task(segment_id: str): index_processor.load(dataset, [document]) end_at = time.perf_counter() - logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a46eafa797..4ef6e29994 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -9,7 +9,7 @@ from configs import dify_config from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): """ Async Send invite member mail @@ -24,31 +24,38 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam if not mail.is_inited(): return - logging.info(click.style('Start send invite 
member mail to {} in workspace {}'.format(to, workspace_name), - fg='green')) + logging.info( + click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") + ) start_at = time.perf_counter() # send invite member mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}' - if language == 'zh-Hans': - html_content = render_template('invite_member_mail_template_zh-CN.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" + if language == "zh-Hans": + html_content = render_template( + "invite_member_mail_template_zh-CN.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template('invite_member_mail_template_en-US.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + html_content = render_template( + "invite_member_mail_template_en-US.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) \ No newline at end of file + logging.exception("Send invite member mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 4e1b8a8913..cbb78976ca 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -9,7 +9,7 @@ from configs import dify_config from 
extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_reset_password_mail_task(language: str, to: str, token: str): """ Async Send reset password mail @@ -20,26 +20,24 @@ def send_reset_password_mail_task(language: str, to: str, token: str): if not mail.is_inited(): return - logging.info(click.style('Start password reset mail to {}'.format(to), fg='green')) + logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) start_at = time.perf_counter() # send reset password mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}' - if language == 'zh-Hans': - html_content = render_template('reset_password_mail_template_zh-CN.html', - to=to, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}" + if language == "zh-Hans": + html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url) mail.send(to=to, subject="重置您的 Dify 密码", html=html_content) else: - html_content = render_template('reset_password_mail_template_en-US.html', - to=to, - url=url) + html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url) mail.send(to=to, subject="Reset Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send password reset mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Send password reset mail to {} failed".format(to)) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 6b4cab55b3..260069c6e2 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -10,7 +10,7 @@ from models.model import Message from models.workflow import WorkflowRun -@shared_task(queue='ops_trace') +@shared_task(queue="ops_trace") def 
process_trace_tasks(tasks_data): """ Async process trace tasks @@ -20,17 +20,17 @@ def process_trace_tasks(tasks_data): """ from core.ops.ops_trace_manager import OpsTraceManager - trace_info = tasks_data.get('trace_info') - app_id = tasks_data.get('app_id') - trace_info_type = tasks_data.get('trace_info_type') + trace_info = tasks_data.get("trace_info") + app_id = tasks_data.get("app_id") + trace_info_type = tasks_data.get("trace_info_type") trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - if trace_info.get('message_data'): - trace_info['message_data'] = Message.from_dict(data=trace_info['message_data']) - if trace_info.get('workflow_data'): - trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data']) - if trace_info.get('documents'): - trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']] + if trace_info.get("message_data"): + trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"]) + if trace_info.get("workflow_data"): + trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) + if trace_info.get("documents"): + trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] try: if trace_instance: diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 02278f512b..18bae14ffa 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -10,7 +10,7 @@ from extensions.ext_database import db from models.dataset import Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def recover_document_indexing_task(dataset_id: str, document_id: str): """ Async recover document @@ -19,16 +19,13 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): Usage: recover_document_indexing_task.delay(dataset_id, document_id) """ - logging.info(click.style('Recover document: {}'.format(document_id), fg='green')) + 
logging.info(click.style("Recover document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") try: indexing_runner = IndexingRunner() @@ -39,8 +36,10 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): elif document.indexing_status == "indexing": indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() - logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + logging.info( + click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4efe7ee38c..66f78636ec 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -33,9 +33,9 @@ from models.web import PinnedConversation, SavedMessage from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun -@shared_task(queue='app_deletion', bind=True, max_retries=3) +@shared_task(queue="app_deletion", bind=True, max_retries=3) def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): - logging.info(click.style(f'Start deleting app and related data: {tenant_id}:{app_id}', fg='green')) + logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green")) start_at = time.perf_counter() try: # Delete related data 
@@ -59,13 +59,14 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() - logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) + logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: logging.exception( - click.style(f"Database error occurred while deleting app {app_id} and related data", fg='red')) + click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red") + ) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds except Exception as e: - logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg='red')) + logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red")) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds @@ -77,7 +78,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): """select id from app_model_configs where app_id=:app_id limit 1000""", {"app_id": app_id}, del_model_config, - "app model config" + "app model config", ) @@ -85,12 +86,7 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False) - _delete_records( - """select id from sites where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_site, - "site" - ) + _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") def _delete_app_api_tokens(tenant_id: str, app_id: str): @@ -98,10 +94,7 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", - {"app_id": app_id}, - 
del_api_token, - "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" ) @@ -113,44 +106,47 @@ def _delete_installed_apps(tenant_id: str, app_id: str): """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_installed_app, - "installed app" + "installed app", ) def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(recommended_app_id: str): db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", {"app_id": app_id}, del_recommended_app, - "recommended app" + "recommended app", ) def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(annotation_hit_history_id: str): db.session.query(AppAnnotationHitHistory).filter( - AppAnnotationHitHistory.id == annotation_hit_history_id).delete(synchronize_session=False) + AppAnnotationHitHistory.id == annotation_hit_history_id + ).delete(synchronize_session=False) _delete_records( """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_hit_history, - "annotation hit history" + "annotation hit history", ) def del_annotation_setting(annotation_setting_id: str): db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from app_annotation_settings where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_setting, - "annotation setting" + "annotation setting", ) @@ -162,7 +158,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): """select id from app_dataset_joins where app_id=:app_id limit 1000""", {"app_id": 
app_id}, del_dataset_join, - "dataset join" + "dataset join", ) @@ -174,7 +170,7 @@ def _delete_app_workflows(tenant_id: str, app_id: str): """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow, - "workflow" + "workflow", ) @@ -186,89 +182,93 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_run, - "workflow run" + "workflow run", ) def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == workflow_node_execution_id).delete(synchronize_session=False) + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_node_execution, - "workflow node execution" + "workflow node execution", ) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_app_log, - "workflow app log" + "workflow app log", ) def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(conversation_id: str): 
db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", {"app_id": app_id}, del_conversation, - "conversation" + "conversation", ) + def _delete_conversation_variables(*, app_id: str): stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) with db.engine.connect() as conn: conn.execute(stmt) conn.commit() - logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green')) + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False) - db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) + db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete( - synchronize_session=False) + db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False) db.session.query(Message).filter(Message.id == 
message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_message, - "message" + """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" ) def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(tool_provider_id: str): db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tool_provider, - "tool workflow provider" + "tool workflow provider", ) @@ -280,7 +280,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tag_binding, - "tag binding" + "tag binding", ) @@ -292,20 +292,21 @@ def _delete_end_users(tenant_id: str, app_id: str): """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_end_user, - "end user" + "end user", ) def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(trace_app_config_id: str): db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", {"app_id": app_id}, del_trace_app_config, - "trace app config" + "trace app config", ) @@ -321,7 +322,7 @@ def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: s try: delete_func(record_id) db.session.commit() - logging.info(click.style(f"Deleted {name} {record_id}", fg='green')) + logging.info(click.style(f"Deleted {name} 
{record_id}", fg="green")) except Exception: logging.exception(f"Error occurred while deleting {name} {record_id}") continue diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index cff8dddc53..1909eaf341 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -11,7 +11,7 @@ from extensions.ext_redis import redis_client from models.dataset import Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def remove_document_from_index_task(document_id: str): """ Async Remove document from index @@ -19,23 +19,23 @@ def remove_document_from_index_task(document_id: str): Usage: remove_document_from_index.delay(document_id) """ - logging.info(click.style('Start remove document segments from index: {}'.format(document_id), fg='green')) + logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) start_at = time.perf_counter() document = db.session.query(Document).filter(Document.id == document_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if document.indexing_status != 'completed': + if document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) try: dataset = document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() @@ -49,7 +49,10 @@ def remove_document_from_index_task(document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document removed from index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + click.style( + "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green" + ) + ) 
except Exception: logging.exception("remove document from index failed") if not document.archived: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 1114809b30..73471fd6e7 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): """ Async process document @@ -27,22 +27,23 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() for document_id in document_ids: - retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id) + retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -50,11 +51,10 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): redis_client.delete(retry_indexing_cache_key) return - logging.info(click.style('Start retry document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) try: if document: # clean old data @@ -70,7 +70,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -79,13 +79,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(retry_indexing_cache_key) pass end_at = time.perf_counter() - 
logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 320da8718a..99fb66e1f3 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -13,7 +13,7 @@ from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ Async process document @@ -26,22 +26,23 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) + sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -49,11 +50,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): redis_client.delete(sync_indexing_cache_key) return - logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() try: if document: # clean old data @@ -69,7 +67,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -78,13 +76,13 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) pass end_at = time.perf_counter() - 
logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 09569df8bf..84ec157323 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -15,13 +15,14 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import cn from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' -import { fetchAppDetail } from '@/service/apps' -import { useAppContext } from '@/context/app-context' +import { fetchAppDetail, fetchAppSSO } from '@/service/apps' +import AppContext, { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' @@ -52,6 +53,7 @@ const AppDetailLayout: FC = (props) => { icon: NavIcon selectedIcon: NavIcon }>>([]) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const getNavigations = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { const navs = [ @@ -114,8 +116,13 @@ const AppDetailLayout: FC = (props) => { router.replace(`/app/${appId}/configuration`) } else { - setAppDetail(res) + setAppDetail({ ...res, enable_sso: false }) setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + if (systemFeatures.enable_web_sso_switch_component) { + fetchAppSSO({ appId 
}).then((ssoRes) => { + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + }) + } } }).catch((e: any) => { if (e.status === 404) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx index 5fa9a2e406..3584e13733 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx @@ -2,22 +2,25 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' +import { useContext, useContextSelector } from 'use-context-selector' import AppCard from '@/app/components/app/overview/appCard' import Loading from '@/app/components/base/loading' import { ToastContext } from '@/app/components/base/toast' import { fetchAppDetail, + fetchAppSSO, + updateAppSSO, updateAppSiteAccessToken, updateAppSiteConfig, updateAppSiteStatus, } from '@/service/apps' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { UpdateAppSiteCodeResponse } from '@/models/app' import { asyncRunSafe } from '@/utils' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import type { IAppCardProps } from '@/app/components/app/overview/appCard' import { useStore as useAppStore } from '@/app/components/app/store' +import AppContext from '@/context/app-context' export type ICardViewProps = { appId: string @@ -28,11 +31,20 @@ const CardView: FC = ({ appId }) => { const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const updateAppDetail = async () => { - fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - setAppDetail(res) - }) + try { + const res = await 
fetchAppDetail({ url: '/apps', id: appId }) + if (systemFeatures.enable_web_sso_switch_component) { + const ssoRes = await fetchAppSSO({ appId }) + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + } + else { + setAppDetail({ ...res }) + } + } + catch (error) { console.error(error) } } const handleCallbackResult = (err: Error | null, message?: string) => { @@ -81,6 +93,16 @@ const CardView: FC = ({ appId }) => { if (!err) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + if (systemFeatures.enable_web_sso_switch_component) { + const [sso_err] = await asyncRunSafe( + updateAppSSO({ id: appId, enabled: params.enable_sso }) as Promise, + ) + if (sso_err) { + handleCallbackResult(sso_err) + return + } + } + handleCallbackResult(err) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 7744bab7ba..8e3d8f9ec6 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -10,7 +10,7 @@ import { TracingProvider } from './type' import ProviderConfigModal from './provider-config-modal' import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -121,12 +121,11 @@ const ConfigPopup: FC = ({ <> {providerAllNotConfigured ? 
( - {switchContent} - - + ) : switchContent} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx index 9119deede8..934eb681b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx @@ -3,7 +3,7 @@ import { ChevronDoubleDownIcon } from '@heroicons/react/20/solid' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import React, { useCallback } from 'react' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -25,9 +25,8 @@ const ToggleFoldBtn: FC = ({ return ( // text-[0px] to hide spacing between tooltip elements
- {isFold && (
@@ -39,7 +38,7 @@ const ToggleFoldBtn: FC = ({
)} -
+
) } diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/DatasetCard.tsx index 096d9d357e..f66c6bccf4 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/DatasetCard.tsx @@ -4,9 +4,7 @@ import { useContext } from 'use-context-selector' import { useRouter } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { - RiMoreFill, -} from '@remixicon/react' +import { RiMoreFill } from '@remixicon/react' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import { ToastContext } from '@/app/components/base/toast' @@ -129,10 +127,9 @@ const DatasetCard = ({
{dataset.name}
{!dataset.embedding_available && ( - {t('dataset.unavailable')} + {t('dataset.unavailable')} )} diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index f1388f41de..c939cb7bb3 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -1,10 +1,6 @@ import React from 'react' -import { - InformationCircleIcon, -} from '@heroicons/react/24/outline' -import Tooltip from '../base/tooltip' import AppIcon from '../base/app-icon' -import { randomString } from '@/utils' +import Tooltip from '@/app/components/base/tooltip' export type IAppBasicProps = { iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion' @@ -74,9 +70,17 @@ export default function AppBasic({ icon, icon_background, name, type, hoverTip,
{name} {hoverTip - && - - } + && + {hoverTip} +
+ } + popupClassName='ml-1' + triggerClassName='w-4 h-4 ml-1' + position='top' + /> + }
{type}
} diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index fb3ceadc0d..a07fe69821 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -9,7 +9,6 @@ import produce from 'immer' import { RiDeleteBinLine, RiErrorWarningFill, - RiQuestionLine, } from '@remixicon/react' import s from './style.module.css' import MessageTypeSelector from './message-type-selector' @@ -174,12 +173,12 @@ const AdvancedPromptInput: FC = ({
{t('appDebug.pageTitle.line1')}
- {t('appDebug.promptTip')} - } - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + /> )}
{canDelete && ( diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index adcfcdd126..69e01a8e22 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' -import { - RiQuestionLine, -} from '@remixicon/react' import produce from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' @@ -156,12 +153,12 @@ const Prompt: FC = ({
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
{!readonly && ( - {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + /> )}
diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 82a220c6db..802528e0af 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -8,7 +8,6 @@ import { useContext } from 'use-context-selector' import produce from 'immer' import { RiDeleteBinLine, - RiQuestionLine, } from '@remixicon/react' import Panel from '../base/feature-panel' import EditModal from './config-modal' @@ -282,11 +281,13 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar
{t('appDebug.variableTitle')}
{!readonly && ( - - {t('appDebug.variableTip')} -
} selector='config-var-tooltip'> - - + + {t('appDebug.variableTip')} +
+ } + /> )} } diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index 9b12e059b5..515709bff1 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import Panel from '../base/feature-panel' import ParamConfig from './param-config' @@ -33,11 +30,13 @@ const ConfigVision: FC = () => { title={
{t('appDebug.vision.name')}
- - {t('appDebug.vision.description')} -
} selector='config-vision-tooltip'> - - + + {t('appDebug.vision.description')} + + } + /> } headerRight={ diff --git a/web/app/components/app/configuration/config-vision/param-config-content.tsx b/web/app/components/app/configuration/config-vision/param-config-content.tsx index 89fad411e7..cb01f57254 100644 --- a/web/app/components/app/configuration/config-vision/param-config-content.tsx +++ b/web/app/components/app/configuration/config-vision/param-config-content.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React from 'react' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import RadioGroup from './radio-group' import ConfigContext from '@/context/debug-configuration' import { Resolution, TransferMethod } from '@/types/app' @@ -37,13 +34,15 @@ const ParamConfigContent: FC = () => {
{t('appDebug.vision.visionSettings.resolution')}
- - {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - + + {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } + /> {
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - -
+ + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} + + } + /> = ({ {icon}
{name}
{description} } - selector={`agent-setting-tooltip-${name}`} > -
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 16f2257c38..1280fd8928 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -7,13 +7,11 @@ import produce from 'immer' import { RiDeleteBinLine, RiHammerFill, - RiQuestionLine, } from '@remixicon/react' import { useFormattingChangedDispatcher } from '../../../debug/hooks' import SettingBuiltInTool from './setting-built-in-tool' import cn from '@/utils/classnames' import Panel from '@/app/components/app/configuration/base/feature-panel' -import Tooltip from '@/app/components/base/tooltip' import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' import OperationBtn from '@/app/components/app/configuration/base/operation-btn' import AppIcon from '@/app/components/base/app-icon' @@ -23,7 +21,7 @@ import type { AgentTool } from '@/types/app' import { type Collection, CollectionType } from '@/app/components/tools/types' import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' import AddToolModal from '@/app/components/tools/add-tool-modal' @@ -68,11 +66,13 @@ const AgentTools: FC = () => { title={
{t('appDebug.agent.tools.name')}
- - {t('appDebug.agent.tools.description')} -
} selector='config-tools-tooltip'> - - + + {t('appDebug.agent.tools.description')} +
+ } + /> } headerRight={ @@ -119,19 +119,20 @@ const AgentTools: FC = () => { className={cn((item.isDeleted || item.notAuthor) ? 'line-through opacity-50' : '', 'grow w-0 ml-2 leading-[18px] text-[13px] font-medium text-gray-800 truncate')} > {item.provider_type === CollectionType.builtIn ? item.provider_name : item.tool_label} - {item.tool_name} - +
{(item.isDeleted || item.notAuthor) ? (
-
{ if (item.notAuthor) @@ -139,7 +140,7 @@ const AgentTools: FC = () => { }}>
-
+
{ const newModelConfig = produce(modelConfig, (draft) => { @@ -155,16 +156,17 @@ const AgentTools: FC = () => { ) : (
- -
{ +
{ setCurrentTool(item) setIsShowSettingTool(true) }}>
- +
{ const newModelConfig = produce(modelConfig, (draft) => { diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx index 7b369d9d79..a528b2288c 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -39,10 +39,9 @@ const CardItem: FC = ({
{config.name}
{!config.embedding_available && ( - {t('dataset.unavailable')} + {t('dataset.unavailable')} )}
diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.tsx index be0ae47242..0de182b0a9 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import type { Props } from './var-picker' import VarPicker from './var-picker' import cn from '@/utils/classnames' @@ -24,13 +21,12 @@ const ContextVar: FC = (props) => {
{t('appDebug.feature.dataSet.queryVariable.title')}
- {t('appDebug.feature.dataSet.queryVariable.tip')} -
} - selector='context-var-tooltip' - > - - + popupContent={ +
+ {t('appDebug.feature.dataSet.queryVariable.tip')} +
+ } + />
diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 3656bf6ea7..7f55649dab 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -5,7 +5,6 @@ import type { FC } from 'react' import { useTranslation } from 'react-i18next' import { RiAlertFill, - RiQuestionLine, } from '@remixicon/react' import WeightedScore from './weighted-score' import TopKItem from '@/app/components/base/param-item/top-k-item' @@ -23,7 +22,7 @@ import ModelSelector from '@/app/components/header/account-setting/model-provide import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelConfig } from '@/app/components/workflow/types' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { DataSet, @@ -173,7 +172,7 @@ const ConfigContent: FC = ({ title={(
{t('appDebug.datasetConfig.retrieveOneWay.title')} - {t('dataset.nTo1RetrievalLegacy')} @@ -181,7 +180,7 @@ const ConfigContent: FC = ({ )} >
legacy
-
+
)} description={t('appDebug.datasetConfig.retrieveOneWay.description')} @@ -250,12 +249,15 @@ const ConfigContent: FC = ({ onClick={() => handleRerankModeChange(option.value)} >
{option.label}
- {option.tips}
} - hideArrow - > - - + + {option.tips} +
+ } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + /> )) } @@ -281,9 +283,15 @@ const ConfigContent: FC = ({ ) }
{t('common.modelProvider.rerankModel.key')}
- {t('common.modelProvider.rerankModel.tip')}}> - - + + {t('common.modelProvider.rerankModel.tip')} + + } + popupClassName='ml-0.5' + triggerClassName='ml-0.5 w-3.5 h-3.5' + />
= ({
{t('common.modelProvider.systemReasoningModel.key')}
- - - + />
= ({ { currentModel && currentModel.status !== ModelStatusEnum.active && ( - + - + ) }
diff --git a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx index e27eec46c8..199558f4aa 100644 --- a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx @@ -2,9 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import Panel from '@/app/components/app/configuration/base/feature-panel' import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' import Tooltip from '@/app/components/base/tooltip' @@ -15,13 +12,15 @@ const SuggestedQuestionsAfterAnswer: FC = () => { return ( +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} -
} selector='suggestion-question-tooltip'> - - + + {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} +
+ } + /> } headerIcon={} diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index 0192024c83..dc9b0a4333 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -16,7 +16,7 @@ import { AppType, ModelModeType } from '@/types/app' import Select from '@/app/components/base/select' import { DEFAULT_VALUE_MAX_LEN } from '@/config' import Button from '@/app/components/base/button' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader' import type { VisionFile, VisionSettings } from '@/types/app' @@ -207,6 +207,7 @@ const PromptValuePanel: FC = ({ {canNotRun ? ( {renderRunButton()} ) diff --git a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx index b2c6792107..111c380afc 100644 --- a/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx +++ b/web/app/components/app/configuration/toolbox/annotation/annotation-ctrl-btn/index.tsx @@ -8,7 +8,7 @@ import { MessageCheckRemove, MessageFastPlus } from '@/app/components/base/icons import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' import { Edit04 } from '@/app/components/base/icons/src/vender/line/general' import RemoveAnnotationConfirmModal from '@/app/components/app/annotation/remove-annotation-confirm-modal' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { addAnnotation, delAnnotation } from '@/service/annotation' import Toast from '@/app/components/base/toast' import { useProviderContext } from 
'@/context/provider-context' @@ -99,8 +99,9 @@ const CacheCtrlBtn: FC = ({ ) : answer ? ( -
= ({ >
-
+
) : null } -
= ({ >
-
+
{title}
- {tooltip}
} - > - - + />
{children}
diff --git a/web/app/components/app/configuration/tools/index.tsx b/web/app/components/app/configuration/tools/index.tsx index 2ec8f55441..03525d7232 100644 --- a/web/app/components/app/configuration/tools/index.tsx +++ b/web/app/components/app/configuration/tools/index.tsx @@ -7,11 +7,10 @@ import { RiAddLine, RiArrowDownSLine, RiDeleteBinLine, - RiQuestionLine, } from '@remixicon/react' import ConfigContext from '@/context/debug-configuration' import Switch from '@/app/components/base/switch' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { Tool03 } from '@/app/components/base/icons/src/vender/solid/general' import { Settings01, @@ -107,9 +106,13 @@ const Tools = () => {
{t('appDebug.feature.tools.title')}
- {t('appDebug.feature.tools.tips')}}> - - + + {t('appDebug.feature.tools.tips')} + + } + /> { !expanded && !!externalDataToolsConfig.length && ( @@ -143,7 +146,7 @@ const Tools = () => { background={item.icon_background} />
{item.label}
-
{ > {item.variable}
-
+
{
{t('app.newApp.captionAppType')}
- {t('app.newApp.chatbotDescription')}
} @@ -120,9 +119,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.chatbot')}
- - +
{t('app.newApp.completionDescription')}
@@ -143,9 +141,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.newApp.completeApp')}
- - + {t('app.newApp.agentDescription')} } @@ -164,9 +161,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.agent')}
-
- +
{t('app.newApp.workflowDescription')}
@@ -188,7 +184,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => {
{t('app.types.workflow')}
BETA -
+ {showChatBotType && ( diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 40d171761b..2258b69394 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -5,10 +5,9 @@ import useSWR from 'swr' import { HandThumbDownIcon, HandThumbUpIcon, - InformationCircleIcon, XMarkIcon, } from '@heroicons/react/24/outline' -import { RiEditFill } from '@remixicon/react' +import { RiEditFill, RiQuestionLine } from '@remixicon/react' import { get } from 'lodash-es' import InfiniteScroll from 'react-infinite-scroll-component' import dayjs from 'dayjs' @@ -20,7 +19,6 @@ import { useTranslation } from 'react-i18next' import s from './style.module.css' import VarPanel from './var-panel' import cn from '@/utils/classnames' -import { randomString } from '@/utils' import type { FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' import type { Annotation, ChatConversationFullDetailResponse, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationFullDetailResponse, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log' import type { App } from '@/types/app' @@ -28,7 +26,6 @@ import Loading from '@/app/components/base/loading' import Drawer from '@/app/components/base/drawer' import Popover from '@/app/components/base/popover' import Chat from '@/app/components/base/chat/chat' -import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import { fetchChatConversationDetail, fetchChatMessages, fetchCompletionConversationDetail, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' import { TONE_LIST } from '@/config' @@ -42,7 +39,7 @@ import MessageLogModal from '@/app/components/base/message-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { 
useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { CopyIcon } from '@/app/components/base/copy-icon' dayjs.extend(utc) @@ -346,11 +343,11 @@ function DetailPanel{isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')} {isChatMode && (
- +
{detail.id}
-
+
)} @@ -380,7 +377,7 @@ function DetailPanel {targetTone} - + } htmlContent={
@@ -641,13 +638,12 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => { return ( {`${t('appLog.detail.annotationTip', { user: annotation?.account?.name })} ${formatTime(annotation?.created_at || dayjs().unix(), 'MM-DD hh:mm A')}`} } - className={(isHighlight && !isChatMode) ? '' : '!hidden'} - selector={`highlight-${randomString(16)}`} + popupClassName={(isHighlight && !isChatMode) ? '' : '!hidden'} >
{value || '-'} diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx index ea0b793857..0d0b95c0a0 100644 --- a/web/app/components/app/overview/appCard.tsx +++ b/web/app/components/app/overview/appCard.tsx @@ -27,10 +27,11 @@ import ShareQRCode from '@/app/components/base/qrcode' import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' import type { AppDetailResponse } from '@/models/app' import { useAppContext } from '@/context/app-context' +import type { AppSSO } from '@/types/app' export type IAppCardProps = { className?: string - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial cardType?: 'api' | 'webapp' customBgColor?: string onChangeStatus: (val: boolean) => Promise @@ -194,8 +195,7 @@ function AppCard({ )} {isApp && isCurrentWorkspaceManager && (
diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 9e5d5af0da..fa023192c9 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -153,8 +153,7 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, classNam
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 8da9b9864f..02eff06c0a 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -4,21 +4,25 @@ import React, { useEffect, useState } from 'react' import { ChevronRightIcon } from '@heroicons/react/20/solid' import Link from 'next/link' import { Trans, useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import AppIcon from '@/app/components/base/app-icon' +import Switch from '@/app/components/base/switch' import { SimpleSelect } from '@/app/components/base/select' import type { AppDetailResponse } from '@/models/app' -import type { AppIconType, Language } from '@/types/app' +import type { AppIconType, AppSSO, Language } from '@/types/app' import { useToastContext } from '@/app/components/base/toast' import { languages } from '@/i18n/language' +import Tooltip from '@/app/components/base/tooltip' +import AppContext from '@/context/app-context' import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import AppIconPicker from '@/app/components/base/app-icon-picker' export type ISettingsModalProps = { isChat: boolean - appInfo: AppDetailResponse + appInfo: AppDetailResponse & Partial isShow: boolean defaultValue?: string onClose: () => void @@ -39,6 +43,7 @@ export type ConfigParams = { icon: string icon_background?: string show_workflow_steps: boolean + enable_sso?: boolean } const prefixSettings = 'appOverview.overview.appInfo.settings' @@ -50,6 +55,7 @@ const SettingsModal: FC = ({ onClose, onSave, }) => { + const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) const 
{ @@ -76,6 +82,7 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + enable_sso: appInfo.enable_sso, }) const [language, setLanguage] = useState(default_language) const [saveLoading, setSaveLoading] = useState(false) @@ -98,6 +105,7 @@ const SettingsModal: FC = ({ privacyPolicy: privacy_policy, customDisclaimer: custom_disclaimer, show_workflow_steps, + enable_sso: appInfo.enable_sso, }) setLanguage(default_language) setAppIcon(icon_type === 'image' @@ -149,6 +157,7 @@ const SettingsModal: FC = ({ icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, show_workflow_steps: inputInfo.show_workflow_steps, + enable_sso: inputInfo.enable_sso, } await onSave?.(params) setSaveLoading(false) @@ -206,22 +215,43 @@ const SettingsModal: FC = ({ defaultValue={language} onSelect={item => setLanguage(item.value as Language)} /> - {(appInfo.mode === 'workflow' || appInfo.mode === 'advanced-chat') && <> -
{t(`${prefixSettings}.workflow.title`)}
- setInputInfo({ ...inputInfo, show_workflow_steps: item.value === 'true' })} - /> - } +
+

{t(`${prefixSettings}.workflow.title`)}

+
+
{t(`${prefixSettings}.workflow.subTitle`)}
+ setInputInfo({ ...inputInfo, show_workflow_steps: v })} + /> +
+

{t(`${prefixSettings}.workflow.showDesc`)}

+
+ {isChat && <>
{t(`${prefixSettings}.chatColorTheme`)}

{t(`${prefixSettings}.chatColorThemeDesc`)}

} + {systemFeatures.enable_web_sso_switch_component &&
+

{t(`${prefixSettings}.sso.label`)}

+
+
{t(`${prefixSettings}.sso.title`)}
+ {t(`${prefixSettings}.sso.tooltip`)}
+ } + asChild={false} + > + setInputInfo({ ...inputInfo, enable_sso: v })}> + +
+

{t(`${prefixSettings}.sso.description`)}

+
} {!isShowMore &&
setIsShowMore(true)}>
{t(`${prefixSettings}.more.entry`)}
diff --git a/web/app/components/app/store.ts b/web/app/components/app/store.ts index a89b96d65d..0209102372 100644 --- a/web/app/components/app/store.ts +++ b/web/app/components/app/store.ts @@ -1,9 +1,9 @@ import { create } from 'zustand' -import type { App } from '@/types/app' +import type { App, AppSSO } from '@/types/app' import type { IChatItem } from '@/app/components/base/chat/chat/type' type State = { - appDetail?: App + appDetail?: App & Partial appSidebarExpand: string currentLogItem?: IChatItem currentLogModalActiveTab: string @@ -13,7 +13,7 @@ type State = { } type Action = { - setAppDetail: (appDetail?: App) => void + setAppDetail: (appDetail?: App & Partial) => void setAppSiderbarExpand: (state: string) => void setCurrentLogItem: (item?: IChatItem) => void setCurrentLogModalActiveTab: (tab: string) => void diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 675f58b530..d57c79b571 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -83,25 +83,25 @@ const AudioBtn = ({ }[audioState] return ( -
+
diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 8ec5c0f3b2..f54e736826 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -17,7 +17,7 @@ import { ThumbsDown, ThumbsUp, } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Log from '@/app/components/base/chat/chat/log' type OperationProps = { @@ -162,28 +162,34 @@ const Operation: FC = ({ { config?.supportFeedback && !localFeedback?.rating && onFeedback && !isOpeningStatement && (
- +
handleFeedback('like')} >
-
- + +
handleFeedback('dislike')} >
-
+
) } { config?.supportFeedback && localFeedback?.rating && onFeedback && !isOpeningStatement && ( - +
= ({ ) }
-
+ ) }
diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index 0c083157a1..c4578fab62 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -17,7 +17,7 @@ import { TransferMethod } from '../types' import { useChatWithHistoryContext } from '../chat-with-history/context' import type { Theme } from '../embedded-chatbot/theme/theme-context' import { CssTransform } from '../embedded-chatbot/theme/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { ToastContext } from '@/app/components/base/toast' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import VoiceInput from '@/app/components/base/voice-input' @@ -220,7 +220,7 @@ const ChatInput: FC = ({ {isMobile ? sendBtn : ( -
{t('common.operation.send')} Enter
@@ -229,7 +229,7 @@ const ChatInput: FC = ({ } > {sendBtn} -
+ )}
{ diff --git a/web/app/components/base/chat/embedded-chatbot/header.tsx b/web/app/components/base/chat/embedded-chatbot/header.tsx index c35f98e3f2..a5c74434c6 100644 --- a/web/app/components/base/chat/embedded-chatbot/header.tsx +++ b/web/app/components/base/chat/embedded-chatbot/header.tsx @@ -41,9 +41,7 @@ const Header: FC = ({
{ onCreateNewChat?.() diff --git a/web/app/components/base/chat/embedded-chatbot/index.tsx b/web/app/components/base/chat/embedded-chatbot/index.tsx index d34fe164d1..480adaae2d 100644 --- a/web/app/components/base/chat/embedded-chatbot/index.tsx +++ b/web/app/components/base/chat/embedded-chatbot/index.tsx @@ -88,9 +88,7 @@ const Chatbot = () => { {!isMobile && (
diff --git a/web/app/components/base/copy-btn/index.tsx b/web/app/components/base/copy-btn/index.tsx index a03b991057..2acb5d8e76 100644 --- a/web/app/components/base/copy-btn/index.tsx +++ b/web/app/components/base/copy-btn/index.tsx @@ -1,10 +1,9 @@ 'use client' -import { useRef, useState } from 'react' +import { useState } from 'react' import { t } from 'i18next' import copy from 'copy-to-clipboard' import s from './style.module.css' import Tooltip from '@/app/components/base/tooltip' -import { randomString } from '@/utils' type ICopyBtnProps = { value: string @@ -18,14 +17,11 @@ const CopyBtn = ({ isPlain, }: ICopyBtnProps) => { const [isCopied, setIsCopied] = useState(false) - const selector = useRef(`copy-tooltip-${randomString(4)}`) return (
{ +const CopyFeedback = ({ content, className }: Props) => { const { t } = useTranslation() const [isCopied, setIsCopied] = useState(false) @@ -30,8 +28,7 @@ const CopyFeedback = ({ content, selectorId, className }: Props) => { return ( { className={`w-8 h-8 cursor-pointer hover:bg-gray-100 rounded-lg ${ className ?? '' }`} - onMouseLeave={onMouseLeave} >
- +
) } diff --git a/web/app/components/base/copy-icon/index.tsx b/web/app/components/base/copy-icon/index.tsx index 0eb0356ffe..425a9ad293 100644 --- a/web/app/components/base/copy-icon/index.tsx +++ b/web/app/components/base/copy-icon/index.tsx @@ -3,7 +3,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { debounce } from 'lodash-es' import copy from 'copy-to-clipboard' -import TooltipPlus from '../tooltip-plus' +import Tooltip from '../tooltip' import { Clipboard, ClipboardCheck, @@ -29,7 +29,7 @@ export const CopyIcon = ({ content }: Props) => { }, 100) return ( - { ) }
- +
) } diff --git a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx index e424c4ead5..e6d0b6e7e0 100644 --- a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx +++ b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx @@ -2,11 +2,8 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { MessageSmileSquare } from '@/app/components/base/icons/src/vender/solid/communication' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const SuggestedQuestionsAfterAnswer: FC = () => { const { t } = useTranslation() @@ -18,9 +15,7 @@ const SuggestedQuestionsAfterAnswer: FC = () => {
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - - +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx index a5a2eb7bb7..e923d9a333 100644 --- a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx +++ b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx @@ -2,9 +2,6 @@ import useSWR from 'swr' import produce from 'immer' import React, { Fragment } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' import { usePathname } from 'next/navigation' import { useTranslation } from 'react-i18next' import { Listbox, Transition } from '@headlessui/react' @@ -74,13 +71,16 @@ const VoiceParamConfig = ({
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - + + {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( +
{item} +
+ ))} +
+ } + />
= ({ )} {item.progress === -1 && ( - - + )}
)} diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 20beb7a9a1..318b0fb0e3 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -1,10 +1,6 @@ 'use client' import type { FC } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' - -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' @@ -40,9 +36,9 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, )} {name} {!noTooltip && ( - {tip}
}> - -
+ {tip}
} + /> )}
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 39193fc31d..65f3dad3a2 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -25,7 +25,7 @@ import { BubbleX, Env } from '@/app/components/base/icons/src/vender/line/others import { VarBlockIcon } from '@/app/components/workflow/block-icon' import { Line3 } from '@/app/components/base/icons/src/public/common' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' type WorkflowVariableBlockComponentProps = { nodeKey: string @@ -113,9 +113,9 @@ const WorkflowVariableBlockComponent = ({ if (!node && !isEnv && !isChatVar) { return ( - + {Item} - +
) } diff --git a/web/app/components/base/qrcode/index.tsx b/web/app/components/base/qrcode/index.tsx index 721f8c9029..c9323992e9 100644 --- a/web/app/components/base/qrcode/index.tsx +++ b/web/app/components/base/qrcode/index.tsx @@ -2,8 +2,8 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import QRCode from 'qrcode.react' -import Tooltip from '../tooltip' import QrcodeStyle from './style.module.css' +import Tooltip from '@/app/components/base/tooltip' type Props = { content: string @@ -51,8 +51,7 @@ const ShareQRCode = ({ content, selectorId, className }: Props) => { return (
-) - -const Tooltip: FC = ({ - position = 'top', - triggerMethod = 'hover', - disabled = false, - popupContent, - children, - hideArrow, - popupClassName, - offset, - asChild, -}) => { - const [open, setOpen] = useState(false) - const [isHoverPopup, { - setTrue: setHoverPopup, - setFalse: setNotHoverPopup, - }] = useBoolean(false) - - const isHoverPopupRef = useRef(isHoverPopup) - useEffect(() => { - isHoverPopupRef.current = isHoverPopup - }, [isHoverPopup]) - - const [isHoverTrigger, { - setTrue: setHoverTrigger, - setFalse: setNotHoverTrigger, - }] = useBoolean(false) - - const isHoverTriggerRef = useRef(isHoverTrigger) - useEffect(() => { - isHoverTriggerRef.current = isHoverTrigger - }, [isHoverTrigger]) - - const handleLeave = (isTrigger: boolean) => { - if (isTrigger) - setNotHoverTrigger() - - else - setNotHoverPopup() - - // give time to move to the popup - setTimeout(() => { - if (!isHoverPopupRef.current && !isHoverTriggerRef.current) - setOpen(false) - }, 500) - } - - return ( - - triggerMethod === 'click' && setOpen(v => !v)} - onMouseEnter={() => { - if (triggerMethod === 'hover') { - setHoverTrigger() - setOpen(true) - } - }} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} - asChild={asChild} - > - {children} - - -
triggerMethod === 'hover' && setHoverPopup()} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} - > - {popupContent} - {!hideArrow && arrow} -
-
-
- ) -} - -export default React.memo(Tooltip) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e7795c6537..f3b4cff132 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -1,52 +1,112 @@ 'use client' import type { FC } from 'react' -import React from 'react' -import { Tooltip as ReactTooltip } from 'react-tooltip' // fixed version to 5.8.3 https://github.com/ReactTooltip/react-tooltip/issues/972 -import classNames from '@/utils/classnames' -import 'react-tooltip/dist/react-tooltip.css' - -type TooltipProps = { - selector: string - content?: string +import React, { useEffect, useRef, useState } from 'react' +import { useBoolean } from 'ahooks' +import type { OffsetOptions, Placement } from '@floating-ui/react' +import { RiQuestionLine } from '@remixicon/react' +import cn from '@/utils/classnames' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +export type TooltipProps = { + position?: Placement + triggerMethod?: 'hover' | 'click' + triggerClassName?: string disabled?: boolean - htmlContent?: React.ReactNode - className?: string // This should use !impornant to override the default styles eg: '!bg-white' - position?: 'top' | 'right' | 'bottom' | 'left' - clickable?: boolean - children: React.ReactNode - noArrow?: boolean + popupContent?: React.ReactNode + children?: React.ReactNode + popupClassName?: string + offset?: OffsetOptions + needsDelay?: boolean + asChild?: boolean } const Tooltip: FC = ({ - selector, - content, - disabled, position = 'top', + triggerMethod = 'hover', + triggerClassName, + disabled = false, + popupContent, children, - htmlContent, - className, - clickable, - noArrow, + popupClassName, + offset, + asChild = true, + needsDelay = false, }) => { + const [open, setOpen] = useState(false) + const [isHoverPopup, { + setTrue: setHoverPopup, + setFalse: 
setNotHoverPopup, + }] = useBoolean(false) + + const isHoverPopupRef = useRef(isHoverPopup) + useEffect(() => { + isHoverPopupRef.current = isHoverPopup + }, [isHoverPopup]) + + const [isHoverTrigger, { + setTrue: setHoverTrigger, + setFalse: setNotHoverTrigger, + }] = useBoolean(false) + + const isHoverTriggerRef = useRef(isHoverTrigger) + useEffect(() => { + isHoverTriggerRef.current = isHoverTrigger + }, [isHoverTrigger]) + + const handleLeave = (isTrigger: boolean) => { + if (isTrigger) + setNotHoverTrigger() + + else + setNotHoverPopup() + + // give time to move to the popup + if (needsDelay) { + setTimeout(() => { + if (!isHoverPopupRef.current && !isHoverTriggerRef.current) + setOpen(false) + }, 500) + } + else { + setOpen(false) + } + } + return ( -
- {React.cloneElement(children as React.ReactElement, { - 'data-tooltip-id': selector, - }) - } - + triggerMethod === 'click' && setOpen(v => !v)} + onMouseEnter={() => { + if (triggerMethod === 'hover') { + setHoverTrigger() + setOpen(true) + } + }} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} + asChild={asChild} > - {htmlContent && htmlContent} - -
+ {children ||
} + + + {popupContent && (
triggerMethod === 'hover' && setHoverPopup()} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} + > + {popupContent} +
)} +
+ ) } -export default Tooltip +export default React.memo(Tooltip) diff --git a/web/app/components/billing/pricing/plan-item.tsx b/web/app/components/billing/pricing/plan-item.tsx index 87a20437c3..b6ac17472e 100644 --- a/web/app/components/billing/pricing/plan-item.tsx +++ b/web/app/components/billing/pricing/plan-item.tsx @@ -2,14 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import { Plan } from '../type' import { ALL_PLANS, NUM_INFINITE, contactSalesUrl, contractSales, unAvailable } from '../config' import Toast from '../../base/toast' -import TooltipPlus from '../../base/tooltip-plus' +import Tooltip from '../../base/tooltip' import { PlanRange } from './select-plan-range' import cn from '@/utils/classnames' import { useAppContext } from '@/context/app-context' @@ -30,13 +27,11 @@ const KeyValue = ({ label, value, tooltip }: { label: string; value: string | nu
{label}
{tooltip && ( - {tooltip}
} - > - - + /> )}
{value}
@@ -136,25 +131,21 @@ const PlanItem: FC = ({
+
{t('billing.plansCommon.supportItems.llmLoadingBalancing')}
- {t('billing.plansCommon.supportItems.llmLoadingBalancingTooltip')}
} - > - - + />
+
 {t('billing.plansCommon.supportItems.ragAPIRequest')}
- {t('billing.plansCommon.ragAPIRequestTooltip')}
} - > - - + />
{comingSoon}
diff --git a/web/app/components/billing/priority-label/index.tsx b/web/app/components/billing/priority-label/index.tsx index d8ad27b6e0..36338cf4a8 100644 --- a/web/app/components/billing/priority-label/index.tsx +++ b/web/app/components/billing/priority-label/index.tsx @@ -9,7 +9,7 @@ import { ZapFast, ZapNarrow, } from '@/app/components/base/icons/src/vender/solid/general' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const PriorityLabel = () => { const { t } = useTranslation() @@ -27,7 +27,7 @@ const PriorityLabel = () => { }, [plan]) return ( -
{`${t('billing.plansCommon.documentProcessingPriority')}: ${t(`billing.plansCommon.priority.${priority}`)}`}
{ @@ -53,7 +53,7 @@ const PriorityLabel = () => { } {t(`billing.plansCommon.priority.${priority}`)} -
+ ) } diff --git a/web/app/components/billing/usage-info/index.tsx b/web/app/components/billing/usage-info/index.tsx index e929249584..ee41508ea6 100644 --- a/web/app/components/billing/usage-info/index.tsx +++ b/web/app/components/billing/usage-info/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { InfoCircle } from '../../base/icons/src/vender/line/general' import ProgressBar from '../progress-bar' import { NUM_INFINITE } from '../config' import Tooltip from '@/app/components/base/tooltip' @@ -48,11 +47,13 @@ const UsageInfo: FC = ({
{name}
{tooltip && ( - - {tooltip} - } selector='config-var-tooltip'> - - + + {tooltip} + + } + /> )}
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 98676f2e83..54a0963f59 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -2,15 +2,13 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' + import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import { RETRIEVE_METHOD } from '@/types/app' import Switch from '@/app/components/base/switch' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -114,9 +112,11 @@ const RetrievalParamConfig: FC = ({ )}
{t('common.modelProvider.rerankModel.key')} - {t('common.modelProvider.rerankModel.tip')}
}> - - + {t('common.modelProvider.rerankModel.tip')}
+ } + /> = ({
{option.label}
{option.tips}} - hideArrow - > - - + triggerClassName='ml-0.5 w-3.5 h-4.5' + /> )) } diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 1e340d692f..574dd083c7 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -22,7 +22,7 @@ import { Plan } from '@/app/components/billing/type' import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { sleep } from '@/utils' type Props = { @@ -259,16 +259,18 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index
{`${getSourcePercent(indexingStatusDetail)}%`}
)} {indexingStatusDetail.indexing_status === 'error' && indexingStatusDetail.error && ( - - {indexingStatusDetail.error} - - )}> + + {indexingStatusDetail.error} + + )} + >
Error
-
+
)} {indexingStatusDetail.indexing_status === 'error' && !indexingStatusDetail.error && (
diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 3849f817d6..1b56241191 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -7,7 +7,6 @@ import { XMarkIcon } from '@heroicons/react/20/solid' import { RocketLaunchIcon } from '@heroicons/react/24/outline' import { RiCloseLine, - RiQuestionLine, } from '@remixicon/react' import Link from 'next/link' import { groupBy } from 'lodash-es' @@ -43,7 +42,6 @@ import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Tooltip from '@/app/components/base/tooltip' -import TooltipPlus from '@/app/components/base/tooltip-plus' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -556,7 +554,7 @@ const StepTwo = ({ className='border-[0.5px] !h-8 hover:outline hover:outline-[0.5px] hover:outline-gray-300 text-gray-700 font-medium bg-white shadow-[0px_1px_2px_0px_rgba(16,24,40,0.05)]' onClick={setShowPreview} > - +
{t('datasetCreation.stepTwo.previewTitleButton')} @@ -628,13 +626,13 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.overlap')} - - {t('datasetCreation.stepTwo.overlapTip')} -
- }> - - + + {t('datasetCreation.stepTwo.overlapTip')} +
+ } + />
= ({
{label}
{isRequired && *} {tooltip && ( - {tooltip}
- }> - - + {tooltip} + } + popupClassName='relative top-[3px] w-3 h-3 ml-1' + /> )} +