diff --git a/api/.ruff.toml b/api/.ruff.toml index 9668dc9f76..9a15754d9a 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -45,6 +45,7 @@ select = [ "G001", # don't use str format to logging messages "G003", # don't use + in logging messages "G004", # don't use f-strings to format logging messages + "UP042", # use StrEnum ] ignore = [ diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 9a8e840554..1400ee7085 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi @@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("console", __name__, url_prefix="/console/api") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Console API", + description="Console management APIs for app configuration, monitoring, and administration", +) + +# Create namespace +console_ns = Namespace("console", description="Console management API operations", path="/") # File api.add_resource(FileApi, "/files/upload") @@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport] +from . import ( + admin, # pyright: ignore[reportUnusedImport] + apikey, # pyright: ignore[reportUnusedImport] + extension, # pyright: ignore[reportUnusedImport] + feature, # pyright: ignore[reportUnusedImport] + init_validate, # pyright: ignore[reportUnusedImport] + ping, # pyright: ignore[reportUnusedImport] + setup, # pyright: ignore[reportUnusedImport] + version, # pyright: ignore[reportUnusedImport] +) # Import app controllers from .app import ( @@ -103,6 +122,23 @@ from .explore import ( saved_message, # pyright: ignore[reportUnusedImport] ) +# Import tag controllers +from .tag import tags # pyright: ignore[reportUnusedImport] + +# Import workspace controllers +from .workspace import ( + account, # pyright: ignore[reportUnusedImport] + agent_providers, # pyright: ignore[reportUnusedImport] + endpoint, # pyright: ignore[reportUnusedImport] + load_balancing_config, # pyright: ignore[reportUnusedImport] + members, # pyright: ignore[reportUnusedImport] + model_providers, # pyright: ignore[reportUnusedImport] + models, # pyright: ignore[reportUnusedImport] + plugin, # pyright: ignore[reportUnusedImport] + tool_providers, # pyright: ignore[reportUnusedImport] + workspace, # pyright: ignore[reportUnusedImport] +) + # Explore Audio api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") @@ -174,19 +210,4 @@ api.add_resource( InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) -# Import tag controllers -from .tag import tags # pyright: ignore[reportUnusedImport] - -# Import workspace controllers -from .workspace import ( - account, # pyright: ignore[reportUnusedImport] - agent_providers, # pyright: ignore[reportUnusedImport] - endpoint, # pyright: ignore[reportUnusedImport] - load_balancing_config, # pyright: ignore[reportUnusedImport] - members, # pyright: ignore[reportUnusedImport] - model_providers, # pyright: ignore[reportUnusedImport] - models, # pyright: ignore[reportUnusedImport] - plugin, # pyright: ignore[reportUnusedImport] - tool_providers, # pyright: ignore[reportUnusedImport] - workspace, # pyright: ignore[reportUnusedImport] -) +api.add_namespace(console_ns) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 1306efacf4..93f242ad28 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db from models.model import App, InstalledApp, RecommendedApp @@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]): return decorated +@console_ns.route("/admin/insert-explore-apps") class InsertExploreAppListApi(Resource): + @api.doc("insert_explore_app") + @api.doc(description="Insert or update an app in the explore list") + @api.expect( + api.model( + "InsertExploreAppRequest", + { + "app_id": fields.String(required=True, description="Application ID"), + "desc": fields.String(description="App description"), + "copyright": fields.String(description="Copyright information"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "language": fields.String(required=True, description="Language code"), + "category": fields.String(required=True, description="App category"), + "position": fields.Integer(required=True, description="Display position"), + }, + ) + ) + @api.response(200, "App updated successfully") + @api.response(201, "App inserted successfully") + @api.response(404, "App not found") @only_edition_cloud @admin_required def post(self): @@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/admin/insert-explore-apps/") class InsertExploreAppApi(Resource): + @api.doc("delete_explore_app") + @api.doc(description="Remove an app from the explore list") + @api.doc(params={"app_id": "Application ID to remove"}) + @api.response(204, "App removed successfully") @only_edition_cloud @admin_required def delete(self, app_id): @@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource): db.session.commit() return {"result": "success"}, 204 - - -api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") -api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 58a1d437d1..06de2fa6b6 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,7 +14,7 @@ from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from . import api +from . import api, console_ns from .wraps import account_initialization_required, setup_required api_key_fields = { @@ -135,7 +135,25 @@ class BaseApiKeyResource(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_app_api_keys") + @api.doc(description="Get all API keys for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for an app""" + return super().get(resource_id) + + @api.doc("create_app_api_key") + @api.doc(description="Create a new API key for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for an app""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -147,7 +165,16 @@ class AppApiKeyListResource(BaseApiKeyListResource): token_prefix = "app-" +@console_ns.route("/apps//api-keys/") class AppApiKeyResource(BaseApiKeyResource): + @api.doc("delete_app_api_key") + @api.doc(description="Delete an API key for an app") + @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for an app""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -158,7 +185,25 @@ class AppApiKeyResource(BaseApiKeyResource): resource_id_field = "app_id" +@console_ns.route("/datasets//api-keys") class DatasetApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_dataset_api_keys") + @api.doc(description="Get all API keys for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for a dataset""" + return super().get(resource_id) + + @api.doc("create_dataset_api_key") + @api.doc(description="Create a new API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for a dataset""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -170,7 +215,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): token_prefix = "ds-" +@console_ns.route("/datasets//api-keys/") class DatasetApiKeyResource(BaseApiKeyResource): + @api.doc("delete_dataset_api_key") + @api.doc(description="Delete an API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for a dataset""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -179,9 +233,3 @@ class DatasetApiKeyResource(BaseApiKeyResource): resource_type = "dataset" resource_model = Dataset resource_id_field = "dataset_id" - - -api.add_resource(AppApiKeyListResource, "/apps//api-keys") -api.add_resource(AppApiKeyResource, "/apps//api-keys/") -api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") -api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index e82e403ec2..8cdadfb03c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,8 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService +active_check_parser = reqparse.RequestParser() +active_check_parser.add_argument( + "workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" +) +active_check_parser.add_argument( + "email", type=email, required=False, nullable=True, location="args", help="Email address" +) +active_check_parser.add_argument( + "token", type=str, required=True, nullable=False, location="args", help="Activation token" +) + +@console_ns.route("/activate/check") class ActivateCheckApi(Resource): + @api.doc("check_activation_token") + @api.doc(description="Check if activation token is valid") + @api.expect(active_check_parser) + @api.response( + 200, + "Success", + api.model( + "ActivationCheckResponse", + { + "is_valid": fields.Boolean(description="Whether token is valid"), + "data": fields.Raw(description="Activation data if valid"), + }, + ), + ) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") - parser.add_argument("email", type=email, required=False, nullable=True, location="args") - parser.add_argument("token", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + args = active_check_parser.parse_args() workspaceId = args["workspace_id"] reg_email = args["email"] @@ -38,18 +60,36 @@ class ActivateCheckApi(Resource): return {"is_valid": False} +active_parser = reqparse.RequestParser() +active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") +active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") +active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") +active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") +active_parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" +) +active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") + + +@console_ns.route("/activate") class ActivateApi(Resource): + @api.doc("activate_account") + @api.doc(description="Activate account with invitation token") + @api.expect(active_parser) + @api.response( + 200, + "Account activated successfully", + api.model( + "ActivationResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.Raw(description="Login token data"), + }, + ), + ) + @api.response(400, "Already activated or invalid token") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - parser.add_argument("email", type=email, required=False, nullable=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - parser.add_argument( - "interface_language", type=supported_language, required=True, nullable=False, location="json" - ) - parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") - args = parser.parse_args() + args = active_parser.parse_args() invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: @@ -70,7 +110,3 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} - - -api.add_resource(ActivateCheckApi, "/activate/check") -api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 8f57b3d03e..fc4ba3a2c7 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,11 +3,11 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -28,7 +28,21 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/data-source/") class OAuthDataSource(Resource): + @api.doc("oauth_data_source") + @api.doc(description="Get OAuth authorization URL for data source provider") + @api.doc(params={"provider": "Data source provider name (notion)"}) + @api.response( + 200, + "Authorization URL or internal setup success", + api.model( + "OAuthDataSourceResponse", + {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, + ), + ) + @api.response(400, "Invalid provider") + @api.response(403, "Admin privileges required") def get(self, provider: str): # The role of the current user in the table must be admin or owner if not current_user.is_admin_or_owner: @@ -49,7 +63,19 @@ class OAuthDataSource(Resource): return {"data": auth_url}, 200 +@console_ns.route("/oauth/data-source/callback/") class OAuthDataSourceCallback(Resource): + @api.doc("oauth_data_source_callback") + @api.doc(description="Handle OAuth callback from data source provider") + @api.doc( + params={ + "provider": "Data source provider name (notion)", + "code": "Authorization code from OAuth provider", + "error": "Error message from OAuth provider", + } + ) + @api.response(302, "Redirect to console with result") + @api.response(400, "Invalid provider") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") +@console_ns.route("/oauth/data-source/binding/") class OAuthDataSourceBinding(Resource): + @api.doc("oauth_data_source_binding") + @api.doc(description="Bind OAuth data source with authorization code") + @api.doc( + params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} + ) + @api.response( + 200, + "Data source binding success", + api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or code") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource): return {"result": "success"}, 200 +@console_ns.route("/oauth/data-source///sync") class OAuthDataSourceSync(Resource): + @api.doc("oauth_data_source_sync") + @api.doc(description="Sync data from OAuth data source") + @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) + @api.response( + 200, + "Data source sync success", + api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or sync failed") @setup_required @login_required @account_initialization_required @@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource): return {"error": "OAuth data source process failed"}, 400 return {"result": "success"}, 200 - - -api.add_resource(OAuthDataSource, "/oauth/data-source/") -api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") -api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") -api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ede0696854..7f34adc0f3 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,12 +2,12 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from constants.languages import languages -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.auth.error import ( EmailCodeError, EmailPasswordResetLimitError, @@ -28,7 +28,32 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService +@console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): + @api.doc("send_forgot_password_email") + @api.doc(description="Send password reset email") + @api.expect( + api.model( + "ForgotPasswordEmailRequest", + { + "email": fields.String(required=True, description="Email address"), + "language": fields.String(description="Language for email (zh-Hans/en-US)"), + }, + ) + ) + @api.response( + 200, + "Email sent successfully", + api.model( + "ForgotPasswordEmailResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.String(description="Reset token"), + "code": fields.String(description="Error code if account not found"), + }, + ), + ) + @api.response(400, "Invalid email or rate limit exceeded") @setup_required @email_password_login_enabled def post(self): @@ -61,7 +86,33 @@ class ForgotPasswordSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/forgot-password/validity") class ForgotPasswordCheckApi(Resource): + @api.doc("check_forgot_password_code") + @api.doc(description="Verify password reset code") + @api.expect( + api.model( + "ForgotPasswordCheckRequest", + { + "email": fields.String(required=True, description="Email address"), + "code": fields.String(required=True, description="Verification code"), + "token": fields.String(required=True, description="Reset token"), + }, + ) + ) + @api.response( + 200, + "Code verified successfully", + api.model( + "ForgotPasswordCheckResponse", + { + "is_valid": fields.Boolean(description="Whether code is valid"), + "email": fields.String(description="Email address"), + "token": fields.String(description="New reset token"), + }, + ), + ) + @api.response(400, "Invalid code or token") @setup_required @email_password_login_enabled def post(self): @@ -100,7 +151,26 @@ class ForgotPasswordCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/forgot-password/resets") class ForgotPasswordResetApi(Resource): + @api.doc("reset_password") + @api.doc(description="Reset password with verification token") + @api.expect( + api.model( + "ForgotPasswordResetRequest", + { + "token": fields.String(required=True, description="Verification token"), + "new_password": fields.String(required=True, description="New password"), + "password_confirm": fields.String(required=True, description="Password confirmation"), + }, + ) + ) + @api.response( + 200, + "Password reset successfully", + api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid token or password mismatch") @setup_required @email_password_login_enabled def post(self): @@ -172,8 +242,3 @@ class ForgotPasswordResetApi(Resource): pass except AccountRegisterError: raise AccountInFreezeError() - - -api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") -api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") -api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 06151ee39b..c3c9de1589 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService -from .. import api +from .. import api, console_ns logger = logging.getLogger(__name__) @@ -50,7 +50,13 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/login/") class OAuthLogin(Resource): + @api.doc("oauth_login") + @api.doc(description="Initiate OAuth login process") + @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}) + @api.response(302, "Redirect to OAuth authorization URL") + @api.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() @@ -63,7 +69,19 @@ class OAuthLogin(Resource): return redirect(auth_url) +@console_ns.route("/oauth/authorize/") class OAuthCallback(Resource): + @api.doc("oauth_callback") + @api.doc(description="Handle OAuth callback and complete login process") + @api.doc( + params={ + "provider": "OAuth provider name (github/google)", + "code": "Authorization code from OAuth provider", + "state": "Optional state parameter (used for invite token)", + } + ) + @api.response(302, "Redirect to console with access token") + @api.response(400, "OAuth process failed") def get(self, provider: str): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -184,7 +202,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): AccountService.link_account_integrate(provider, user_info.id, account) return account - - -api.add_resource(OAuthLogin, "/oauth/login/") -api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e157041c35..57f5ab191e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants import HIDDEN_VALUE -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required @@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +@console_ns.route("/code-based-extension") class CodeBasedExtensionAPI(Resource): + @api.doc("get_code_based_extension") + @api.doc(description="Get code-based extension data by module name") + @api.expect( + api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name") + ) + @api.response( + 200, + "Success", + api.model( + "CodeBasedExtensionResponse", + {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, + ), + ) @setup_required @login_required @account_initialization_required @@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource): return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} +@console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): + @api.doc("get_api_based_extensions") + @api.doc(description="Get all API-based extensions for current tenant") + @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields))) @setup_required @login_required @account_initialization_required @@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource): tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + @api.doc("create_api_based_extension") + @api.doc(description="Create a new API-based extension") + @api.expect( + api.model( + "CreateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(201, "Extension created successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource): return APIBasedExtensionService.save(extension_data) +@console_ns.route("/api-based-extension/") class APIBasedExtensionDetailAPI(Resource): + @api.doc("get_api_based_extension") + @api.doc(description="Get API-based extension by ID") + @api.doc(params={"id": "Extension ID"}) + @api.response(200, "Success", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -64,6 +100,20 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + @api.doc("update_api_based_extension") + @api.doc(description="Update API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.expect( + api.model( + "UpdateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(200, "Extension updated successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.save(extension_data_from_db) + @api.doc("delete_api_based_extension") + @api.doc(description="Delete API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.response(204, "Extension deleted successfully") @setup_required @login_required @account_initialization_required @@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) return {"result": "success"}, 204 - - -api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") - -api.add_resource(APIBasedExtensionAPI, "/api-based-extension") -api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6236832d39..d43b839291 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,26 +1,40 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from libs.login import login_required from services.feature_service import FeatureService -from . import api +from . import api, console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required +@console_ns.route("/features") class FeatureApi(Resource): + @api.doc("get_tenant_features") + @api.doc(description="Get feature configuration for current tenant") + @api.response( + 200, + "Success", + api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) @setup_required @login_required @account_initialization_required @cloud_utm_record def get(self): + """Get feature configuration for current tenant""" return FeatureService.get_features(current_user.current_tenant_id).model_dump() +@console_ns.route("/system-features") class SystemFeatureApi(Resource): + @api.doc("get_system_features") + @api.doc(description="Get system-wide feature configuration") + @api.response( + 200, + "Success", + api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}), + ) def get(self): + """Get system-wide feature configuration""" return FeatureService.get_system_features().model_dump() - - -api.add_resource(FeatureApi, "/features") -api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2a37b1708a..30b53458b2 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -11,20 +11,47 @@ from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +@console_ns.route("/init") class InitValidateAPI(Resource): + @api.doc("get_init_status") + @api.doc(description="Get initialization validation status") + @api.response( + 200, + "Success", + model=api.model( + "InitStatusResponse", + {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, + ), + ) def get(self): + """Get initialization validation status""" init_status = get_init_validate_status() if init_status: return {"status": "finished"} return {"status": "not_started"} + @api.doc("validate_init_password") + @api.doc(description="Validate initialization password for self-hosted edition") + @api.expect( + api.model( + "InitValidateRequest", + {"password": fields.String(required=True, description="Initialization password", max_length=30)}, + ) + ) + @api.response( + 201, + "Success", + model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Validate initialization password""" # is tenant created tenant_count = TenantService.get_tenant_count() if tenant_count > 0: @@ -52,6 +79,3 @@ def get_init_validate_status(): return db_session.execute(select(DifySetup)).scalar_one_or_none() return True - - -api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 1a53a2347e..29f49b99de 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,14 +1,17 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from . import api, console_ns +@console_ns.route("/ping") class PingApi(Resource): + @api.doc("health_check") + @api.doc(description="Health check endpoint for connection testing") + @api.response( + 200, + "Success", + api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), + ) def get(self): - """ - For connection health check - """ + """Health check endpoint for connection testing""" return {"result": "pong"} - - -api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 8e230496f0..bff5fc1651 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip @@ -7,23 +7,56 @@ from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +@console_ns.route("/setup") class SetupApi(Resource): + @api.doc("get_setup_status") + @api.doc(description="Get system setup status") + @api.response( + 200, + "Success", + api.model( + "SetupStatusResponse", + { + "step": fields.String(description="Setup step status", enum=["not_started", "finished"]), + "setup_at": fields.String(description="Setup completion time (ISO format)", required=False), + }, + ), + ) def get(self): + """Get system setup status""" if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() - if setup_status: + # Check if setup_status is a DifySetup object rather than a bool + if setup_status and not isinstance(setup_status, bool): return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + elif setup_status: + return {"step": "finished"} return {"step": "not_started"} return {"step": "finished"} + @api.doc("setup_system") + @api.doc(description="Initialize system setup with admin account") + @api.expect( + api.model( + "SetupRequest", + { + "email": fields.String(required=True, description="Admin email address"), + "name": fields.String(required=True, description="Admin name (max 30 characters)"), + "password": fields.String(required=True, description="Admin password"), + }, + ) + ) + @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")})) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Initialize system setup with admin account""" # is set up if get_setup_status(): raise AlreadySetupError() @@ -55,6 +88,3 @@ def get_setup_status(): return db.session.query(DifySetup).first() else: return True - - -api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 8409e7d1ab..8d081ad995 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,18 +2,41 @@ import json import logging import requests -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from packaging import version from configs import dify_config -from . import api +from . import api, console_ns logger = logging.getLogger(__name__) +@console_ns.route("/version") class VersionApi(Resource): + @api.doc("check_version_update") + @api.doc(description="Check for application version updates") + @api.expect( + api.parser().add_argument( + "current_version", type=str, required=True, location="args", help="Current application version" + ) + ) + @api.response( + 200, + "Success", + api.model( + "VersionResponse", + { + "version": fields.String(description="Latest version number"), + "release_date": fields.String(description="Release date of latest version"), + "release_notes": fields.String(description="Release notes for latest version"), + "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), + "features": fields.Raw(description="Feature flags and capabilities"), + }, + ), + ) def get(self): + """Check for application version updates""" parser = reqparse.RequestParser() parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() @@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: except version.InvalidVersion: logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) return False - - -api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 08bab6fcb5..0a2c8fcfb4 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,14 +1,22 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.agent_service import AgentService +@console_ns.route("/workspaces/current/agent-providers") class AgentProviderListApi(Resource): + @api.doc("list_agent_providers") + @api.doc(description="Get list of available agent providers") + @api.response( + 200, + "Success", + fields.List(fields.Raw(description="Agent provider information")), + ) @setup_required @login_required @account_initialization_required @@ -21,7 +29,16 @@ class AgentProviderListApi(Resource): return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) +@console_ns.route("/workspaces/current/agent-provider/") class AgentProviderApi(Resource): + @api.doc("get_agent_provider") + @api.doc(description="Get specific agent provider details") + @api.doc(params={"provider_name": "Agent provider name"}) + @api.response( + 200, + "Success", + fields.Raw(description="Agent provider details"), + ) @setup_required @login_required @account_initialization_required @@ -30,7 +47,3 @@ class AgentProviderApi(Resource): user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) - - -api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") -api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 96e873d42b..0657b764cc 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError @@ -10,7 +10,26 @@ from libs.login import login_required from services.plugin.endpoint_service import EndpointService +@console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): + @api.doc("create_endpoint") + @api.doc(description="Create a new plugin endpoint") + @api.expect( + api.model( + "EndpointCreateRequest", + { + "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), + "settings": fields.Raw(required=True, description="Endpoint settings"), + "name": fields.String(required=True, description="Endpoint name"), + }, + ) + ) + @api.response( + 200, + "Endpoint created successfully", + api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -43,7 +62,20 @@ class EndpointCreateApi(Resource): raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/list") class EndpointListApi(Resource): + @api.doc("list_endpoints") + @api.doc(description="List plugin endpoints with pagination") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + ) + @api.response( + 200, + "Success", + api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}), + ) @setup_required @login_required @account_initialization_required @@ -70,7 +102,23 @@ class EndpointListApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/list/plugin") class EndpointListForSinglePluginApi(Resource): + @api.doc("list_plugin_endpoints") + @api.doc(description="List endpoints for a specific plugin") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") + ) + @api.response( + 200, + "Success", + api.model( + "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} + ), + ) @setup_required @login_required @account_initialization_required @@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/delete") class EndpointDeleteApi(Resource): + @api.doc("delete_endpoint") + @api.doc(description="Delete a plugin endpoint") + @api.expect( + api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint deleted successfully", + api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/update") class EndpointUpdateApi(Resource): + @api.doc("update_endpoint") + @api.doc(description="Update a plugin endpoint") + @api.expect( + api.model( + "EndpointUpdateRequest", + { + "endpoint_id": fields.String(required=True, description="Endpoint ID"), + "settings": fields.Raw(required=True, description="Updated settings"), + "name": fields.String(required=True, description="Updated name"), + }, + ) + ) + @api.response( + 200, + "Endpoint updated successfully", + api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/enable") class EndpointEnableApi(Resource): + @api.doc("enable_endpoint") + @api.doc(description="Enable a plugin endpoint") + @api.expect( + api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint enabled successfully", + api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -177,7 +268,19 @@ class EndpointEnableApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/disable") class EndpointDisableApi(Resource): + @api.doc("disable_endpoint") + @api.doc(description="Disable a plugin endpoint") + @api.expect( + api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint disabled successfully", + api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -198,12 +301,3 @@ class EndpointDisableApi(Resource): tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id ) } - - -api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") -api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") -api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") -api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") -api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") -api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") -api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index a1b8bb7cfe..26fbf7097e 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Files API", description="API for file operations including upload and preview", - doc="/docs", # Enable Swagger UI at /files/docs ) files_ns = Namespace("files", description="File operations", path="/") diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index b09c39309f..f29f624ba5 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Inner API", description="Internal APIs for enterprise features, billing, and plugin communication", - doc="/docs", # Enable Swagger UI at /inner/api/docs ) # Create namespace diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 18b530f2c4..bde0150ffd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -75,9 +75,6 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): if not user_id: user_id = DEFAULT_SERVICE_API_USER_ID - del kwargs["tenant_id"] - del kwargs["user_id"] - try: tenant_model = ( db.session.query(Tenant) diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 43b36a70b4..336a7801bb 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="MCP API", description="API for Model Context Protocol operations", - doc="/docs", # Enable Swagger UI at /mcp/docs ) mcp_ns = Namespace("mcp", description="MCP operations", path="/") diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d69f49d957..a6008fdb99 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Service API", description="API for application services", - doc="/docs", # Enable Swagger UI at /v1/docs ) service_api_ns = Namespace("service_api", description="Service operations", path="/") diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index a825a2a0d8..97bcd3d53c 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Web API", description="Public APIs for web applications including file uploads, chat interactions, and app management", - doc="/docs", # Enable Swagger UI at /api/docs ) # Create namespace diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2c0f6c9759..c1c46891b6 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, AudioTooLargeError, @@ -32,15 +32,16 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@web_ns.route("/audio-to-text") class AudioApi(WebApiResource): audio_to_text_response_fields = { "text": fields.String, } @marshal_with(audio_to_text_response_fields) - @api.doc("Audio to Text") - @api.doc(description="Convert audio file to text using speech-to-text service.") - @api.doc( + @web_ns.doc("Audio to Text") + @web_ns.doc(description="Convert audio file to text using speech-to-text service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -85,6 +86,7 @@ class AudioApi(WebApiResource): raise InternalServerError() +@web_ns.route("/text-to-audio") class TextApi(WebApiResource): text_to_audio_response_fields = { "audio_url": fields.String, @@ -92,9 +94,9 @@ class TextApi(WebApiResource): } @marshal_with(text_to_audio_response_fields) - @api.doc("Text to Audio") - @api.doc(description="Convert text to audio using text-to-speech service.") - @api.doc( + @web_ns.doc("Text to Audio") + @web_ns.doc(description="Convert text to audio using text-to-speech service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -145,7 +147,3 @@ class TextApi(WebApiResource): except Exception as e: logger.exception("Failed to handle post request to TextApi") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a42bf5fc6e..67ae970388 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -4,7 +4,7 @@ from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, CompletionRequestError, @@ -35,10 +35,11 @@ logger = logging.getLogger(__name__) # define completion api for user +@web_ns.route("/completion-messages") class CompletionApi(WebApiResource): - @api.doc("Create Completion Message") - @api.doc(description="Create a completion message for text generation applications.") - @api.doc( + @web_ns.doc("Create Completion Message") + @web_ns.doc(description="Create a completion message for text generation applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, "query": {"description": "Query text for completion", "type": "string", "required": False}, @@ -52,7 +53,7 @@ class CompletionApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -106,11 +107,12 @@ class CompletionApi(WebApiResource): raise InternalServerError() +@web_ns.route("/completion-messages//stop") class CompletionStopApi(WebApiResource): - @api.doc("Stop Completion Message") - @api.doc(description="Stop a running completion message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Completion Message") + @web_ns.doc(description="Stop a running completion message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource): return {"result": "success"}, 200 +@web_ns.route("/chat-messages") class ChatApi(WebApiResource): - @api.doc("Create Chat Message") - @api.doc(description="Create a chat message for conversational applications.") - @api.doc( + @web_ns.doc("Create Chat Message") + @web_ns.doc(description="Create a chat message for conversational applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, "query": {"description": "User query/message", "type": "string", "required": True}, @@ -148,7 +151,7 @@ class ChatApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -207,11 +210,12 @@ class ChatApi(WebApiResource): raise InternalServerError() +@web_ns.route("/chat-messages//stop") class ChatStopApi(WebApiResource): - @api.doc("Stop Chat Message") - @api.doc(description="Stop a running chat message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Chat Message") + @web_ns.doc(description="Stop a running chat message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 24de4f3f2e..03dd986aed 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -3,7 +3,7 @@ from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +@web_ns.route("/conversations") class ConversationListApi(WebApiResource): + @web_ns.doc("Get Conversation List") + @web_ns.doc(description="Retrieve paginated list of conversations for a chat application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of conversations to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + "pinned": { + "description": "Filter by pinned status", + "type": "string", + "enum": ["true", "false"], + "required": False, + }, + "sort_by": { + "description": "Sort order", + "type": "string", + "enum": ["created_at", "-created_at", "updated_at", "-updated_at"], + "required": False, + "default": "-updated_at", + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource): raise NotFound("Last Conversation Not Exists.") +@web_ns.route("/conversations/") class ConversationApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Conversation") + @web_ns.doc(description="Delete a specific conversation.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -76,7 +127,32 @@ class ConversationApi(WebApiResource): return {"result": "success"}, 204 +@web_ns.route("/conversations//name") class ConversationRenameApi(WebApiResource): + @web_ns.doc("Rename Conversation") + @web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "name": {"description": "New conversation name", "type": "string", "required": False}, + "auto_generate": { + "description": "Auto-generate conversation name", + "type": "boolean", + "required": False, + "default": False, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource): raise NotFound("Conversation Not Exists.") +@web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): pin_response_fields = { "result": fields.String, } + @web_ns.doc("Pin Conversation") + @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation pinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource): return {"result": "success"} +@web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): unpin_response_fields = { "result": fields.String, } + @web_ns.doc("Unpin Conversation") + @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation unpinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource): WebConversationService.unpin(app_model, conversation_id, end_user) return {"result": "success"} - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") -api.add_resource(ConversationListApi, "/conversations") -api.add_resource(ConversationApi, "/conversations/") -api.add_resource(ConversationPinApi, "/conversations//pin") -api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 17e06e8856..26c0b133d9 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, @@ -38,6 +38,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { "id": fields.String, @@ -62,6 +63,30 @@ class MessageListApi(WebApiResource): "data": fields.List(fields.Nested(message_fields)), } + @web_ns.doc("Get Message List") + @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") + @web_ns.doc( + params={ + "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, + "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -84,11 +109,36 @@ class MessageListApi(WebApiResource): raise NotFound("First Message Not Exists.") +@web_ns.route("/messages//feedbacks") class MessageFeedbackApi(WebApiResource): feedback_response_fields = { "result": fields.String, } + @web_ns.doc("Create Message Feedback") + @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "rating": { + "description": "Feedback rating", + "type": "string", + "enum": ["like", "dislike"], + "required": False, + }, + "content": {"description": "Feedback content/comment", "type": "string", "required": False}, + } + ) + @web_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource): return {"result": "success"} +@web_ns.route("/messages//more-like-this") class MessageMoreLikeThisApi(WebApiResource): + @web_ns.doc("Generate More Like This") + @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID", "type": "string", "required": True}, + "response_mode": { + "description": "Response mode", + "type": "string", + "enum": ["blocking", "streaming"], + "required": True, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) def get(self, app_model, end_user, message_id): if app_model.mode != "completion": raise NotCompletionAppError() @@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource): raise InternalServerError() +@web_ns.route("/messages//suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): suggested_questions_response_fields = { "data": fields.List(fields.String), } + @web_ns.doc("Get Suggested Questions") + @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a chat app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found or Conversation Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) @@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") -api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 7a9d24114e..96f09c8d3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields @@ -23,6 +23,7 @@ message_fields = { } +@web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource): "result": fields.String, } + @web_ns.doc("Get Saved Messages") + @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": @@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource): return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + @web_ns.doc("Save Message") + @web_ns.doc(description="Save a specific message for later reference.") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID to save", "type": "string", "required": True}, + } + ) + @web_ns.doc( + responses={ + 200: "Message saved successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": @@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource): return {"result": "success"} +@web_ns.route("/saved-messages/") class SavedMessageApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Saved Message") + @web_ns.doc(description="Remove a message from saved messages.") + @web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Message removed successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) return {"result": "success"}, 204 - - -api.add_resource(SavedMessageListApi, "/saved-messages") -api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 91d67bf9d8..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.web import api +from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField @@ -11,6 +11,7 @@ from models.model import Site from services.feature_service import FeatureService +@web_ns.route("/site") class AppSiteApi(WebApiResource): """Resource for app sites.""" @@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource): "custom_config": fields.Raw(attribute="custom_config"), } - @api.doc("Get App Site Info") - @api.doc(description="Retrieve app site information and configuration.") - @api.doc( + @web_ns.doc("Get App Site Info") + @web_ns.doc(description="Retrieve app site information and configuration.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, "/site") - - class AppSiteInfo: """Class to store site information.""" diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3566cfae38..490dce8f05 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -3,7 +3,7 @@ import logging from flask_restx import reqparse from werkzeug.exceptions import InternalServerError -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, NotWorkflowAppError, @@ -29,16 +29,17 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +@web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): - @api.doc("Run Workflow") - @api.doc(description="Execute a workflow with provided inputs and files.") - @api.doc( + @web_ns.doc("Run Workflow") + @web_ns.doc(description="Execute a workflow with provided inputs and files.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True}, "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -84,15 +85,16 @@ class WorkflowRunApi(WebApiResource): raise InternalServerError() +@web_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(WebApiResource): - @api.doc("Stop Workflow Task") - @api.doc(description="Stop a running workflow task.") - @api.doc( + @web_ns.doc("Stop Workflow Task") + @web_ns.doc(description="Stop a running workflow task.") + @web_ns.doc( params={ "task_id": {"description": "Task ID to stop", "type": "string", "required": True}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -113,7 +115,3 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"} - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/services/account_service.py b/api/services/account_service.py index f66c1aa677..f917959350 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -246,6 +246,8 @@ class AccountService: account.name = name if password: + valid_password(password) + # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 6b5ac713e6..dac1fe643a 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -91,6 +91,28 @@ class TestAccountService: assert account.password is None assert account.password_salt is None + def test_create_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account create with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password="invalid_new_password", + ) + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): """ Test account creation when registration is disabled. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py new file mode 100644 index 0000000000..de81295100 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -0,0 +1,1099 @@ +""" +Integration tests for create_segment_to_index_task using TestContainers. + +This module provides comprehensive testing for the create_segment_to_index_task +which handles asynchronous document segment indexing operations. +""" + +import time +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from faker import Faker + +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.create_segment_to_index_task import create_segment_to_index_task + + +class TestCreateSegmentToIndexTask: + """Integration tests for create_segment_to_index_task using testcontainers.""" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database and Redis before each test to ensure isolation.""" + from extensions.ext_database import db + + # Clear all test data + db.session.query(DocumentSegment).delete() + db.session.query(Document).delete() + db.session.query(Dataset).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory, + ): + # Setup default mock returns + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "index_processor_factory": mock_factory, + "index_processor": mock_processor, + } + + def _create_test_account_and_tenant(self, db_session_with_containers): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + """ + Helper method to create a test dataset and document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the dataset + account_id: Account ID for the document + + Returns: + tuple: (dataset, document) - Created dataset and document instances + """ + fake = Faker() + + # Create dataset + dataset = Dataset( + name=fake.company(), + description=fake.text(max_nb_chars=100), + tenant_id=tenant_id, + data_source_type="upload_file", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + created_by=account_id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Create document + document = Document( + name=fake.file_name(), + dataset_id=dataset.id, + tenant_id=tenant_id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account_id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="qa_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + return dataset, document + + def _create_test_segment( + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + ): + """ + Helper method to create a test document segment for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset_id: Dataset ID for the segment + document_id: Document ID for the segment + tenant_id: Tenant ID for the segment + account_id: Account ID for the segment + status: Initial status of the segment + + Returns: + DocumentSegment: Created document segment instance + """ + fake = Faker() + + segment = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + position=1, + content=fake.text(max_nb_chars=500), + answer=fake.text(max_nb_chars=200), + word_count=len(fake.text(max_nb_chars=500).split()), + tokens=len(fake.text(max_nb_chars=500).split()) * 2, + keywords=["test", "document", "segment"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status=status, + created_by=account_id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + return segment + + def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful creation of segment to index. + + This test verifies: + - Segment status transitions from waiting to indexing to completed + - Index processor is called with correct parameters + - Segment metadata is properly updated + - Redis cache key is cleaned up + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify segment status changes + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify Redis cache cleanup + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_segment_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent segment ID. + + This test verifies: + - Task gracefully handles missing segment + - No exceptions are raised + - Database session is properly closed + """ + # Arrange: Use non-existent segment ID + non_existent_segment_id = str(uuid4()) + + # Act & Assert: Task should complete without error + result = create_segment_to_index_task(non_existent_segment_id) + assert result is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_invalid_status( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with invalid status. + + This test verifies: + - Task skips segments not in 'waiting' status + - No processing occurs for invalid status + - Database session is properly closed + """ + # Arrange: Create segment with invalid status + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status unchanged + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated dataset. + + This test verifies: + - Task gracefully handles missing dataset + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid dataset_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + invalid_dataset_id = str(uuid4()) + + # Create document with invalid dataset_id + document = Document( + name="test_doc", + dataset_id=invalid_dataset_id, + tenant_id=tenant.id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account.id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated document. + + This test verifies: + - Task gracefully handles missing document + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid document_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, _ = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + invalid_document_id = str(uuid4()) + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with disabled document. + + This test verifies: + - Task skips segments with disabled documents + - No processing occurs for disabled documents + - Segment status remains unchanged + """ + # Arrange: Create disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Disable the document + document.enabled = False + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_archived( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with archived document. + + This test verifies: + - Task skips segments with archived documents + - No processing occurs for archived documents + - Segment status remains unchanged + """ + # Arrange: Create archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Archive the document + document.archived = True + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_indexing_incomplete( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with document that has incomplete indexing. + + This test verifies: + - Task skips segments with incomplete indexing documents + - No processing occurs for incomplete indexing + - Segment status remains unchanged + """ + # Arrange: Create document with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Set incomplete indexing status + document.indexing_status = "indexing" + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_processor_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of index processor exceptions. + + This test verifies: + - Task properly handles index processor failures + - Segment status is updated to error + - Segment is disabled with error information + - Redis cache is cleaned up despite errors + """ + # Arrange: Create test data and mock processor exception + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock processor to raise exception + mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Processor failed") + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error == "Processor failed" + + # Verify Redis cache cleanup still occurs + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_with_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with custom keywords. + + This test verifies: + - Task accepts and processes keywords parameter + - Keywords are properly passed through the task + - Indexing completes successfully with keywords + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + custom_keywords = ["custom", "keywords", "test"] + + # Act: Execute the task with keywords + create_segment_to_index_task(segment.id, keywords=custom_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_different_doc_forms( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with different document forms. + + This test verifies: + - Task works with various document forms + - Index processor factory receives correct doc_form + - Processing completes successfully for different forms + """ + # Arrange: Test different doc_forms + doc_forms = ["qa_model", "text_model", "web_model"] + + for doc_form in doc_forms: + # Create fresh test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, tenant.id, account.id + ) + + # Update document's doc_form for testing + document.doc_form = doc_form + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify correct doc_form was passed to factory + mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) + + def test_create_segment_to_index_performance_timing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing performance and timing. + + This test verifies: + - Task execution time is reasonable + - Performance metrics are properly recorded + - No significant performance degradation + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task and measure time + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify performance + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + # Verify successful completion + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + def test_create_segment_to_index_concurrent_execution( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test concurrent execution of segment indexing tasks. + + This test verifies: + - Multiple tasks can run concurrently + - No race conditions occur + - All segments are processed correctly + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segments = [] + for i in range(3): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + segment_ids = [segment.id for segment in segments] + for segment_id in segment_ids: + create_segment_to_index_task(segment_id) + + # Assert: Verify all segments processed + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called for each segment + assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 + + def test_create_segment_to_index_large_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with large content. + + This test verifies: + - Task handles large content segments + - Performance remains acceptable with large content + - No memory or processing issues occur + """ + # Arrange: Create segment with large content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Generate large content (simulate large document) + large_content = "Large content " * 1000 # ~15KB content + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=large_content, + answer="Large answer " * 100, + word_count=len(large_content.split()), + tokens=len(large_content.split()) * 2, + keywords=["large", "content", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify successful processing + execution_time = end_time - start_time + assert execution_time < 10.0 # Should complete within 10 seconds + + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_redis_failure( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing when Redis operations fail. + + This test verifies: + - Task continues to work even if Redis fails + - Indexing completes successfully + - Redis errors don't affect core functionality + """ + # Arrange: Create test data and mock Redis failure + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Set up Redis cache key to simulate indexing in progress + cache_key = f"segment_{segment.id}_indexing" + redis_client.set(cache_key, "processing", ex=300) + + # Mock Redis to raise exception in finally block + with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")): + # Act: Execute the task - Redis failure should not prevent completion + with pytest.raises(Exception) as exc_info: + create_segment_to_index_task(segment.id) + + # Verify the exception contains the expected Redis error message + assert "Redis connection failed" in str(exc_info.value) + + # Assert: Verify indexing still completed successfully despite Redis failure + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify Redis cache key still exists (since delete failed) + assert redis_client.exists(cache_key) == 1 + + def test_create_segment_to_index_database_transaction_rollback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with database transaction handling. + + This test verifies: + - Database transactions are properly managed + - Rollback occurs on errors + - Data consistency is maintained + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock global database session to simulate transaction issues + from extensions.ext_database import db + + original_commit = db.session.commit + commit_called = False + + def mock_commit(): + nonlocal commit_called + if not commit_called: + commit_called = True + raise Exception("Database commit failed") + return original_commit() + + db.session.commit = mock_commit + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling and rollback + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error is not None + + # Restore original commit method + db.session.commit = original_commit + + def test_create_segment_to_index_metadata_validation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with metadata validation. + + This test verifies: + - Document metadata is properly constructed + - All required metadata fields are present + - Metadata is correctly passed to index processor + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify index processor was called with correct metadata + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.assert_called_once() + + # Get the call arguments to verify metadata structure + call_args = mock_processor.load.call_args + assert len(call_args[0]) == 2 # dataset and documents + + # Verify basic structure without deep object inspection + called_dataset = call_args[0][0] # first arg should be dataset + assert called_dataset is not None + + documents = call_args[0][1] # second arg should be list of documents + assert len(documents) == 1 + doc = documents[0] + assert doc is not None + + def test_create_segment_to_index_status_transition_flow( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test complete status transition flow during indexing. + + This test verifies: + - Status transitions: waiting -> indexing -> completed + - Timestamps are properly recorded at each stage + - No intermediate states are skipped + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Verify initial state + assert segment.status == "waiting" + assert segment.indexing_at is None + assert segment.completed_at is None + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify final state + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify timestamp ordering + assert segment.indexing_at <= segment.completed_at + + def test_create_segment_to_index_with_empty_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with empty or minimal content. + + This test verifies: + - Task handles empty content gracefully + - Indexing completes successfully with minimal content + - No errors occur with edge case content + """ + # Arrange: Create segment with minimal content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content="", # Empty content + answer="", + word_count=0, + tokens=0, + keywords=[], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_special_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with special characters and unicode content. + + This test verifies: + - Task handles special characters correctly + - Unicode content is processed properly + - No encoding issues occur + """ + # Arrange: Create segment with special characters + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" + unicode_content = "Unicode: δΈ­ζ–‡ζ΅‹θ―• πŸš€ 🌟 πŸ’»" + mixed_content = special_content + "\n" + unicode_content + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=mixed_content, + answer="Special answer: 🎯", + word_count=len(mixed_content.split()), + tokens=len(mixed_content.split()) * 2, + keywords=["special", "unicode", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_long_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with long keyword lists. + + This test verifies: + - Task handles long keyword lists + - Keywords parameter is properly processed + - No performance issues with large keyword sets + """ + # Arrange: Create segment with long keywords + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Create long keyword list + long_keywords = [f"keyword_{i}" for i in range(100)] + + # Act: Execute the task with long keywords + create_segment_to_index_task(segment.id, keywords=long_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with proper tenant isolation. + + This test verifies: + - Tasks are properly isolated by tenant + - No cross-tenant data access occurs + - Tenant boundaries are respected + """ + # Arrange: Create multiple tenants with segments + account1, tenant1 = self._create_test_account_and_tenant(db_session_with_containers) + account2, tenant2 = self._create_test_account_and_tenant(db_session_with_containers) + + dataset1, document1 = self._create_test_dataset_and_document( + db_session_with_containers, tenant1.id, account1.id + ) + dataset2, document2 = self._create_test_dataset_and_document( + db_session_with_containers, tenant2.id, account2.id + ) + + segment1 = self._create_test_segment( + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + ) + segment2 = self._create_test_segment( + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + ) + + # Act: Execute tasks for both tenants + create_segment_to_index_task(segment1.id) + create_segment_to_index_task(segment2.id) + + # Assert: Verify both segments processed independently + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + + assert segment1.status == "completed" + assert segment2.status == "completed" + assert segment1.tenant_id == tenant1.id + assert segment2.tenant_id == tenant2.id + assert segment1.tenant_id != segment2.tenant_id + + def test_create_segment_to_index_with_none_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with None keywords parameter. + + This test verifies: + - Task handles None keywords gracefully + - Default behavior works correctly + - No errors occur with None parameters + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task with None keywords + create_segment_to_index_task(segment.id, keywords=None) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_comprehensive_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Comprehensive integration test covering multiple scenarios. + + This test verifies: + - Complete workflow from creation to completion + - All components work together correctly + - End-to-end functionality is maintained + - Performance and reliability under normal conditions + """ + # Arrange: Create comprehensive test scenario + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Create multiple segments with different characteristics + segments = [] + for i in range(5): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Process all segments + start_time = time.time() + for segment in segments: + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify comprehensive success + total_time = end_time - start_time + assert total_time < 25.0 # Should complete all within 25 seconds + + # Verify all segments processed successfully + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called for each segment + expected_calls = len(segments) + assert mock_external_service_dependencies["index_processor_factory"].call_count == expected_calls + + # Verify Redis cleanup for each segment + for segment in segments: + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index b80c711cac..962a36fe03 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -246,6 +246,43 @@ class TestEmailI18nService: sent_email = mock_sender.sent_emails[0] assert sent_email["subject"] == "Reset Your Dify Password" + def test_subject_format_keyerror_fallback_path( + self, + mock_renderer: MockEmailRenderer, + mock_sender: MockEmailSender, + ): + """Trigger subject KeyError and cover except branch.""" + # Config with subject that references an unknown key (no {application_title} to avoid second format) + config = EmailI18nConfig( + templates={ + EmailType.INVITE_MEMBER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Invite: {unknown_placeholder}", + template_path="invite_member_en.html", + branded_template_path="branded/invite_member_en.html", + ), + } + } + ) + branding_service = MockBrandingService(enabled=False) + service = EmailI18nService( + config=config, + renderer=mock_renderer, + branding_service=branding_service, + sender=mock_sender, + ) + + # Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback + service.send_email( + email_type=EmailType.INVITE_MEMBER, + language_code="en-US", + to="test@example.com", + ) + + assert len(mock_sender.sent_emails) == 1 + # Subject is left unformatted due to KeyError fallback path without application_title + assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}" + def test_send_change_email_old_phase( self, email_config: EmailI18nConfig, diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py new file mode 100644 index 0000000000..a9edb913ea --- /dev/null +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -0,0 +1,122 @@ +from flask import Blueprint, Flask +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, Unauthorized + +from core.errors.error import AppInvokeQuotaExceededError +from libs.external_api import ExternalApi + + +def _create_api_app(): + app = Flask(__name__) + bp = Blueprint("t", __name__) + api = ExternalApi(bp) + + @api.route("/bad-request") + class Bad(Resource): # type: ignore + def get(self): # type: ignore + raise BadRequest("invalid input") + + @api.route("/unauth") + class Unauth(Resource): # type: ignore + def get(self): # type: ignore + raise Unauthorized("auth required") + + @api.route("/value-error") + class ValErr(Resource): # type: ignore + def get(self): # type: ignore + raise ValueError("boom") + + @api.route("/quota") + class Quota(Resource): # type: ignore + def get(self): # type: ignore + raise AppInvokeQuotaExceededError("quota exceeded") + + @api.route("/general") + class Gen(Resource): # type: ignore + def get(self): # type: ignore + raise RuntimeError("oops") + + # Note: We avoid altering default_mediatype to keep normal error paths + + # Special 400 message rewrite + @api.route("/json-empty") + class JsonEmpty(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Force the specific message the handler rewrites + e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" + raise e + + # 400 mapping payload path + @api.route("/param-errors") + class ParamErrors(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Coerce a mapping description to trigger param error shaping + e.description = {"field": "is required"} # type: ignore[assignment] + raise e + + app.register_blueprint(bp, url_prefix="/api") + return app + + +def test_external_api_error_handlers_basic_paths(): + app = _create_api_app() + client = app.test_client() + + # 400 + res = client.get("/api/bad-request") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "bad_request" + assert data["status"] == 400 + + # 401 + res = client.get("/api/unauth") + assert res.status_code == 401 + assert "WWW-Authenticate" in res.headers + + # 400 ValueError + res = client.get("/api/value-error") + assert res.status_code == 400 + assert res.get_json()["code"] == "invalid_param" + + # 500 general + res = client.get("/api/general") + assert res.status_code == 500 + assert res.get_json()["status"] == 500 + + +def test_external_api_json_message_and_bad_request_rewrite(): + app = _create_api_app() + client = app.test_client() + + # JSON empty special rewrite + res = client.get("/api/json-empty") + assert res.status_code == 400 + assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty." + + +def test_external_api_param_mapping_and_quota_and_exc_info_none(): + # Force exc_info() to return (None,None,None) only during request + import libs.external_api as ext + + orig_exc_info = ext.sys.exc_info + try: + ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + + app = _create_api_app() + client = app.test_client() + + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" + + # Quota path β€” depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) + finally: + ext.sys.exc_info = orig_exc_info # type: ignore[assignment] diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py new file mode 100644 index 0000000000..3e0c235fff --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -0,0 +1,19 @@ +import pytest + +from libs.oauth import OAuth + + +def test_oauth_base_methods_raise_not_implemented(): + oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri") + + with pytest.raises(NotImplementedError): + oauth.get_authorization_url() + + with pytest.raises(NotImplementedError): + oauth.get_access_token("code") + + with pytest.raises(NotImplementedError): + oauth.get_raw_user_info("token") + + with pytest.raises(NotImplementedError): + oauth._transform_user_info({}) # type: ignore[name-defined] diff --git a/api/tests/unit_tests/libs/test_sendgrid_client.py b/api/tests/unit_tests/libs/test_sendgrid_client.py new file mode 100644 index 0000000000..85744003c7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_sendgrid_client.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock, patch + +import pytest +from python_http_client.exceptions import UnauthorizedError + +from libs.sendgrid import SendGridClient + + +def _mail(to: str = "user@example.com") -> dict: + return {"to": to, "subject": "Hi", "html": "Hi"} + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_success(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + # nested attribute access: client.mail.send.post + mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + sg.send(_mail()) + + mock_client_cls.assert_called_once() + mock_client.client.mail.send.post.assert_called_once() + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock): + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(ValueError): + sg.send(_mail(to="")) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(UnauthorizedError): + sg.send(_mail()) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = TimeoutError("timeout") + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(TimeoutError): + sg.send(_mail()) diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py new file mode 100644 index 0000000000..fcee01ca00 --- /dev/null +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from libs.smtp import SMTPClient + + +def _mail() -> dict: + return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_plain_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=587, + username="user", + password="pass", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=True, + ) + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10) + assert mock_smtp.ehlo.call_count == 2 + mock_smtp.starttls.assert_called_once() + mock_smtp.login.assert_called_once_with("user", "pass") + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP_SSL") +def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): + # Cover SMTP_SSL branch and TimeoutError handling + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = TimeoutError("timeout") + mock_smtp_ssl_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="", + password="", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + with pytest.raises(TimeoutError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = RuntimeError("oops") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + with pytest.raises(RuntimeError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock): + # Ensure we hit the specific SMTPException except branch + import smtplib + + mock_smtp = MagicMock() + mock_smtp.login.side_effect = smtplib.SMTPException("login-fail") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="user", # non-empty to trigger login + password="pass", + _from="noreply@example.com", + ) + with pytest.raises(smtplib.SMTPException): + client.send(_mail()) + mock_smtp.quit.assert_called_once() diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index 46f992d758..a5813266f1 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -1,5 +1,6 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useChatContext } from '../chat/chat/context' const hasEndThink = (children: any): boolean => { if (typeof children === 'string') @@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => { } const useThinkTimer = (children: any) => { + const { isResponding } = useChatContext() const [startTime] = useState(Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) @@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => { }, [startTime, isComplete]) useEffect(() => { - if (hasEndThink(children)) + if (hasEndThink(children) || !isResponding) setIsComplete(true) - }, [children]) + }, [children, isResponding]) return { elapsedTime, isComplete } }