From e59cc3311db097c4d5d245dde37dd3b6e5c4a181 Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Mon, 22 Sep 2025 10:44:08 +0800 Subject: [PATCH] add: trial api and trial table --- api/configs/feature/__init__.py | 10 + api/controllers/console/__init__.py | 34 ++ api/controllers/console/admin.py | 94 ++++- api/controllers/console/explore/banner.py | 34 ++ api/controllers/console/explore/error.py | 22 ++ .../console/explore/recommended_app.py | 1 + api/controllers/console/explore/trial.py | 349 ++++++++++++++++++ api/controllers/console/explore/wraps.py | 68 +++- ...db42_add_table_explore_banner_and_trial.py | 79 ++++ api/models/__init__.py | 6 + api/models/model.py | 57 +++ api/services/feature_service.py | 4 + api/services/recommended_app_service.py | 37 ++ 13 files changed, 792 insertions(+), 3 deletions(-) create mode 100644 api/controllers/console/explore/banner.py create mode 100644 api/controllers/console/explore/trial.py create mode 100644 api/migrations/versions/2025_09_19_1442-1b435d90db42_add_table_explore_banner_and_trial.py diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b17f30210c..1ea8ce254d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -801,6 +801,16 @@ class MailConfig(BaseSettings): default=None, ) + ENABLE_TRIAL_APP: bool = Field( + description="Enable trial app", + default=False, + ) + + ENABLE_EXPLORE_BANNER: bool = Field( + description="Enable explore banner", + default=False, + ) + class RagEtlConfig(BaseSettings): """ diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ee02ff3937..0e8d9680e2 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -19,6 +19,16 @@ from .explore.message import ( MessageMoreLikeThisApi, MessageSuggestedQuestionApi, ) +from .explore.trial import ( + AppApi, + TrialAppParameterApi, + TrialChatApi, + TrialChatAudioApi, + TrialChatTextApi, + TrialCompletionApi, + TrialMessageSuggestedQuestionApi, + TrialSitApi, +) from .explore.workflow import ( InstalledAppWorkflowRunApi, InstalledAppWorkflowTaskStopApi, @@ -127,10 +137,12 @@ from .datasets.rag_pipeline import ( # Import explore controllers from .explore import ( + banner, installed_app, parameter, recommended_app, saved_message, + trial, ) # Import tag controllers @@ -221,6 +233,26 @@ api.add_resource( InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) +# Explore trial +api.add_resource(TrialChatApi, "/trial-apps//chat-messages", endpoint="trial_app_chat_completion") + +api.add_resource( + TrialMessageSuggestedQuestionApi, + "/trial-apps//messages//suggested-questions", + endpoint="trial_app_suggested_question", +) + +api.add_resource(TrialChatAudioApi, "/trial-apps//audio-to-text", endpoint="trial_app_audio") +api.add_resource(TrialChatTextApi, "/trial-apps//text-to-audio", endpoint="trial_app_text") + +api.add_resource(TrialCompletionApi, "/trial-apps//completion-messages", endpoint="trial_app_completion") + +api.add_resource(TrialSitApi, "/trial-apps//site") + +api.add_resource(TrialAppParameterApi, "/trial-apps//parameters", endpoint="trial_app_parameters") + +api.add_resource(AppApi, "/trial-apps/", endpoint="trial_app") + api.add_namespace(console_ns) __all__ = [ @@ -235,6 +267,7 @@ __all__ = [ "apikey", "app", "audio", + "banner", "billing", "bp", "completion", @@ -288,6 +321,7 @@ __all__ = [ "statistic", "tags", "tool_providers", + "trial", "version", "website", "workflow", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 93f242ad28..ece0582a8b 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,7 +15,7 @@ from constants.languages import supported_language from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db -from models.model import App, InstalledApp, RecommendedApp +from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp def admin_required(view: Callable[P, R]): @@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource): "language": fields.String(required=True, description="Language code"), "category": fields.String(required=True, description="App category"), "position": fields.Integer(required=True, description="Display position"), + "can_trial": fields.Boolean(required=True, description="Can trial"), + "trial_limit": fields.Integer(required=True, description="Trial limit"), }, ) ) @@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource): parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") parser.add_argument("category", type=str, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json") + parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json") + parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json") args = parser.parse_args() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() @@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource): ) db.session.add(recommended_app) + if args["can_trial"]: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == args["app_id"]) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=args["app_id"], + tenant_id=app.tenant_id, + trial_limit=args["trial_limit"], + ) + ) + else: + trial_app.trial_limit = args["trial_limit"] app.is_public = True db.session.commit() @@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource): recommended_app.category = args["category"] recommended_app.position = args["position"] + if args["can_trial"]: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == args["app_id"]) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=args["app_id"], + tenant_id=app.tenant_id, + trial_limit=args["trial_limit"], + ) + ) + else: + trial_app.trial_limit = args["trial_limit"] app.is_public = True db.session.commit() @@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource): for installed_app in installed_apps: session.delete(installed_app) + trial_app = session.execute( + select(TrialApp).where(TrialApp.app_id == recommended_app.app_id) + ).scalar_one_or_none() + if trial_app: + session.delete(trial_app) + db.session.delete(recommended_app) db.session.commit() return {"result": "success"}, 204 + + +@console_ns.route("/admin/insert-explore-banner") +class InsertExploreBanner(Resource): + @api.doc("insert_explore_banner") + @api.doc(description="Insert an explore banner") + @api.expect( + api.model( + "InsertExploreBannerRequest", + { + "content": fields.String(required=True, description="Banner content"), + "link": fields.String(required=True, description="Banner link"), + "sort": fields.Integer(required=True, description="Banner sort"), + }, + ) + ) + @api.response(200, "Banner inserted successfully") + @admin_required + @only_edition_cloud + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("link", type=str, required=True, nullable=False, location="json") + parser.add_argument("sort", type=int, required=True, nullable=False, location="json") + + args = parser.parse_args() + + banner = ExporleBanner( + content=args["content"], + link=args["link"], + sort=args["sort"], + ) + db.session.add(banner) + db.session.commit() + + return {"result": "success"}, 200 + + +@console_ns.route("/admin/delete-explore-banner/") +class DeleteExploreBanner(Resource): + @api.doc("delete_explore_banner") + @api.doc(description="Delete an explore banner") + @api.response(204, "Banner deleted successfully") + @admin_required + @only_edition_cloud + def delete(self, banner_id): + banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none() + if not banner: + raise NotFound(f"Banner '{banner_id}' is not found") + + db.session.delete(banner) + db.session.commit() + + return {"result": "success"}, 204 diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py new file mode 100644 index 0000000000..5e7aa1ec81 --- /dev/null +++ b/api/controllers/console/explore/banner.py @@ -0,0 +1,34 @@ +from flask_restx import Resource + +from controllers.console import api +from controllers.console.explore.wraps import explore_banner_enabled +from extensions.ext_database import db +from models.model import ExporleBanner + + +class BannerApi(Resource): + """Resource for banner list.""" + + @explore_banner_enabled + def get(self): + """Get banner list.""" + banners = ( + db.session.query(ExporleBanner).filter(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all() + ) + + # Convert banners to serializable format + result = [] + for banner in banners: + banner_data = { + "content": banner.content, # Already parsed as JSON by SQLAlchemy + "link": banner.link, + "sort": banner.sort, + "status": banner.status, + "created_at": banner.created_at.isoformat() if banner.created_at else None, + } + result.append(banner_data) + + return result + + +api.add_resource(BannerApi, "/explore/banners") diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 1e05ff4206..e96fa64f84 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException): error_code = "access_denied" description = "App access denied." code = 403 + + +class TrialAppNotAllowed(BaseHTTPException): + """*403* `Trial App Not Allowed` + + Raise if the user has reached the trial app limit. + """ + + error_code = "trial_app_not_allowed" + code = 403 + description = "the app is not allowed to be trial." + + +class TrialAppLimitExceeded(BaseHTTPException): + """*403* `Trial App Limit Exceeded` + + Raise if the user has exceeded the trial app limit. + """ + + error_code = "trial_app_limit_exceeded" + code = 403 + description = "The user has exceeded the trial app limit." diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 974222ddf7..1f26c75a43 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -27,6 +27,7 @@ recommended_app_fields = { "category": fields.String, "position": fields.Integer, "is_listed": fields.Boolean, + "can_trial": fields.Boolean, } recommended_app_list_fields = { diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py new file mode 100644 index 0000000000..c4976c0577 --- /dev/null +++ b/api/controllers/console/explore/trial.py @@ -0,0 +1,349 @@ +import logging + +from flask import request +from flask_restx import Resource, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.common import fields +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + ConversationCompletedError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from controllers.console.app.wraps import get_app_model +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from core.app.entities.app_invoke_entities import InvokeFrom +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from core.model_runtime.errors.invoke import InvokeError +from extensions.ext_database import db +from libs import helper +from libs.helper import uuid_value +from libs.login import current_user +from models import Account +from models.account import TenantStatus +from models.model import AppMode, Site +from services.app_generate_service import AppGenerateService +from services.app_service import AppService +from services.audio_service import AudioService +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) +from services.errors.conversation import ConversationNotExistsError +from services.errors.llm import InvokeRateLimitError +from services.errors.message import ( + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService +from services.recommended_app_service import RecommendedAppService + +logger = logging.getLogger(__name__) + + +class TrialChatApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + args = parser.parse_args() + + args["auto_generate_name"] = False + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + response = AppGenerateService.generate( + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True + ) + RecommendedAppService.add_trial_app_record(app_model.id, current_user.id) + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialMessageSuggestedQuestionApi(TrialAppResource): + @trial_feature_enable + def get(self, trial_app, message_id): + app_model = trial_app + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + message_id = str(message_id) + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + questions = MessageService.get_suggested_questions_after_answer( + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE + ) + except MessageNotExistsError: + raise NotFound("Message not found") + except ConversationNotExistsError: + raise NotFound("Conversation not found") + except SuggestedQuestionsAfterAnswerDisabledError: + raise AppSuggestedQuestionsAfterAnswerDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + return {"data": questions} + + +class TrialChatAudioApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + + file = request.files["file"] + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) + RecommendedAppService.add_trial_app_record(app_model.id, current_user.id) + return response + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except NoAudioUploadedServiceError: + raise NoAudioUploadedError() + except AudioTooLargeServiceError as e: + raise AudioTooLargeError(str(e)) + except UnsupportedAudioTypeServiceError: + raise UnsupportedAudioTypeError() + except ProviderNotSupportSpeechToTextServiceError: + raise ProviderNotSupportSpeechToTextError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception as e: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialChatTextApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + try: + parser = reqparse.RequestParser() + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") + args = parser.parse_args() + + message_id = args.get("message_id", None) + text = args.get("text", None) + voice = args.get("voice", None) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + RecommendedAppService.add_trial_app_record(app_model.id, current_user.id) + return response + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except NoAudioUploadedServiceError: + raise NoAudioUploadedError() + except AudioTooLargeServiceError as e: + raise AudioTooLargeError(str(e)) + except UnsupportedAudioTypeServiceError: + raise UnsupportedAudioTypeError() + except ProviderNotSupportSpeechToTextServiceError: + raise ProviderNotSupportSpeechToTextError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception as e: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialCompletionApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + if app_model.mode != "completion": + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + args = parser.parse_args() + + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + response = AppGenerateService.generate( + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming + ) + + RecommendedAppService.add_trial_app_record(app_model.id, current_user.id) + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialSitApi(Resource): + """Resource for trial app sites.""" + + @trial_feature_enable + @get_app_model + def get(self, app_model): + """Retrieve app site info. + + Returns the site configuration for the application including theme, icons, and text. + """ + site = db.session.query(Site).where(Site.app_id == app_model.id).first() + + if not site: + raise Forbidden() + + assert app_model.tenant + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + + return site + + +class TrialAppParameterApi(Resource): + """Resource for app variables.""" + + @trial_feature_enable + @get_app_model + @marshal_with(fields.parameters_fields) + def get(self, app_model): + """Retrieve app parameters.""" + + if app_model is None: + raise AppUnavailableError() + + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + + return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + + +class AppApi(Resource): + @trial_feature_enable + @get_app_model + def get(self, app_model): + """Get app detail""" + + app_service = AppService() + app_model = app_service.get_app(app_model) + + return app_model diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 3a8ba64a03..be11f6f725 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -2,15 +2,16 @@ from collections.abc import Callable from functools import wraps from typing import Concatenate, ParamSpec, TypeVar +from flask import abort from flask_login import current_user from flask_restx import Resource from werkzeug.exceptions import NotFound -from controllers.console.explore.error import AppAccessDeniedError +from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required -from models import InstalledApp +from models import AccountTrialAppRecord, App, InstalledApp, TrialApp from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | return decorator +def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): + def decorator(view: Callable[Concatenate[App, P], R]): + @wraps(view) + def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): + trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + + if trial_app is None: + raise TrialAppNotAllowed() + app = trial_app.app + + if app is None: + raise TrialAppNotAllowed() + + account_trial_app_record = ( + db.session.query(AccountTrialAppRecord) + .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) + .first() + ) + if account_trial_app_record: + if account_trial_app_record.count >= trial_app.trial_limit: + raise TrialAppLimitExceeded() + + return view(app, *args, **kwargs) + + return decorated + + if view: + return decorator(view) + return decorator + + +def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]: + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if not features.enable_trial_app: + abort(403, "Trial app feature is not enabled.") + return view(*args, **kwargs) + + return decorated + + +def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]: + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if not features.enable_explore_banner: + abort(403, "Explore banner feature is not enabled.") + return view(*args, **kwargs) + + return decorated + + class InstalledAppResource(Resource): # must be reversed if there are multiple decorators @@ -83,3 +137,13 @@ class InstalledAppResource(Resource): account_initialization_required, login_required, ] + + +class TrialAppResource(Resource): + # must be reversed if there are multiple decorators + + method_decorators = [ + trial_app_required, + account_initialization_required, + login_required, + ] diff --git a/api/migrations/versions/2025_09_19_1442-1b435d90db42_add_table_explore_banner_and_trial.py b/api/migrations/versions/2025_09_19_1442-1b435d90db42_add_table_explore_banner_and_trial.py new file mode 100644 index 0000000000..6d20273c5d --- /dev/null +++ b/api/migrations/versions/2025_09_19_1442-1b435d90db42_add_table_explore_banner_and_trial.py @@ -0,0 +1,79 @@ +"""add table explore banner and trial + +Revision ID: 1b435d90db42 +Revises: cf7c38a32b2d +Create Date: 2025-09-19 14:42:58.416649 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1b435d90db42' +down_revision = 'cf7c38a32b2d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('account_trial_app_records', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('count', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'), + sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record') + ) + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False) + batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False) + + op.create_table('exporle_banners', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('content', sa.JSON(), nullable=False), + sa.Column('link', sa.String(length=255), nullable=False), + sa.Column('sort', sa.Integer(), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey') + ) + op.create_table('trial_apps', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('trial_limit', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id', name='trial_app_pkey'), + sa.UniqueConstraint('app_id', name='unique_trail_app_id') + ) + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False) + + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_column('credential_status') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True)) + + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.drop_index('trial_app_tenant_id_idx') + batch_op.drop_index('trial_app_app_id_idx') + + op.drop_table('trial_apps') + op.drop_table('exporle_banners') + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.drop_index('account_trial_app_record_app_id_idx') + batch_op.drop_index('account_trial_app_record_account_id_idx') + + op.drop_table('account_trial_app_records') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 779484283f..6adc8e56b6 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -28,6 +28,7 @@ from .dataset import ( ) from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( + AccountTrialAppRecord, ApiRequest, ApiToken, App, @@ -40,6 +41,7 @@ from .model import ( DatasetRetrieverResource, DifySetup, EndUser, + ExporleBanner, IconType, InstalledApp, Message, @@ -54,6 +56,7 @@ from .model import ( Tag, TagBinding, TraceAppConfig, + TrialApp, UploadFile, ) from .oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -98,6 +101,7 @@ __all__ = [ "Account", "AccountIntegrate", "AccountStatus", + "AccountTrialAppRecord", "ApiRequest", "ApiToken", "ApiToolProvider", @@ -131,6 +135,7 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", + "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", "IconType", @@ -168,6 +173,7 @@ __all__ = [ "ToolLabelBinding", "ToolModelInvoke", "TraceAppConfig", + "TrialApp", "UploadFile", "UserFrom", "Whitelist", diff --git a/api/models/model.py b/api/models/model.py index 9bcb81b41b..7ff5170f1f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -581,6 +581,63 @@ class InstalledApp(Base): return tenant +class TrialApp(Base): + __tablename__ = "trial_apps" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), + sa.Index("trial_app_app_id_idx", "app_id"), + sa.Index("trial_app_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("app_id", name="unique_trail_app_id"), + ) + + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + +class AccountTrialAppRecord(Base): + __tablename__ = "account_trial_app_records" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), + sa.Index("account_trial_app_record_account_id_idx", "account_id"), + sa.Index("account_trial_app_record_app_id_idx", "app_id"), + sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), + ) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + account_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + count = mapped_column(sa.Integer, nullable=False, default=0) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + @property + def user(self) -> Account | None: + user = db.session.query(Account).where(Account.id == self.account_id).first() + return user + + +class ExporleBanner(Base): + __tablename__ = "exporle_banners" + __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + content = mapped_column(sa.JSON, nullable=False) + link = mapped_column(String(255), nullable=False) + sort = mapped_column(sa.Integer, nullable=False) + status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + class OAuthProviderApp(Base): """ Globally shared OAuth provider app information. diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 19d96cb972..74df593782 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel): plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() + enable_trial_app: bool = False + enable_explore_banner: bool = False class FeatureService: @@ -214,6 +216,8 @@ class FeatureService: system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP + system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 544383a106..b0c31e272b 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,4 +1,9 @@ +from sqlalchemy.orm import Session + from configs import dify_config +from extensions.ext_database import db +from models.model import AccountTrialAppRecord, TrialApp +from services.feature_service import FeatureService from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -20,6 +25,15 @@ class RecommendedAppService: ) ) + if FeatureService.get_system_features().enable_trial_app: + apps = result["recommended_apps"] + for app in apps: + app_id = app["app_id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + app["can_trial"] = True + else: + app["can_trial"] = False return result @classmethod @@ -32,4 +46,27 @@ class RecommendedAppService: mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() result: dict = retrieval_instance.get_recommend_app_detail(app_id) + if FeatureService.get_system_features().enable_trial_app: + app_id = result["id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + result["can_trial"] = True + else: + result["can_trial"] = False return result + + @classmethod + def add_trial_app_record(cls, app_id: str, account_id: str): + """ + Add trial app record. + :param app_id: app id + :return: + """ + with Session(db.engine) as session: + account_trial_app_record = session.query(AccountTrialAppRecord).where(TrialApp.app_id == app_id).first() + if account_trial_app_record: + account_trial_app_record.count += 1 + session.commit() + else: + session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id)) + session.commit()