mirror of https://github.com/langgenius/dify.git
add: trial api and trial table
This commit is contained in:
parent
345ac8333c
commit
781abe5f8f
|
|
@ -801,6 +801,16 @@ class MailConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_EXPLORE_BANNER: bool = Field(
|
||||
description="Enable explore banner",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -19,6 +19,16 @@ from .explore.message import (
|
|||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from .explore.trial import (
|
||||
AppApi,
|
||||
TrialAppParameterApi,
|
||||
TrialChatApi,
|
||||
TrialChatAudioApi,
|
||||
TrialChatTextApi,
|
||||
TrialCompletionApi,
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
TrialSitApi,
|
||||
)
|
||||
from .explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
|
|
@ -127,10 +137,12 @@ from .datasets.rag_pipeline import (
|
|||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
installed_app,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
|
|
@ -221,6 +233,26 @@ api.add_resource(
|
|||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
||||
# Explore trial
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -235,6 +267,7 @@ __all__ = [
|
|||
"apikey",
|
||||
"app",
|
||||
"audio",
|
||||
"banner",
|
||||
"billing",
|
||||
"bp",
|
||||
"completion",
|
||||
|
|
@ -288,6 +321,7 @@ __all__ = [
|
|||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trial",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from constants.languages import supported_language
|
|||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
|
|
@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource):
|
|||
"language": fields.String(required=True, description="Language code"),
|
||||
"category": fields.String(required=True, description="App category"),
|
||||
"position": fields.Integer(required=True, description="Display position"),
|
||||
"can_trial": fields.Boolean(required=True, description="Can trial"),
|
||||
"trial_limit": fields.Integer(required=True, description="Trial limit"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource):
|
|||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
|
||||
parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
|
|
@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource):
|
|||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
|
@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource):
|
|||
recommended_app.category = args["category"]
|
||||
recommended_app.position = args["position"]
|
||||
|
||||
if args["can_trial"]:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == args["app_id"])
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=args["app_id"],
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=args["trial_limit"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = args["trial_limit"]
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
|
@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource):
|
|||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBanner(Resource):
|
||||
@api.doc("insert_explore_banner")
|
||||
@api.doc(description="Insert an explore banner")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreBannerRequest",
|
||||
{
|
||||
"content": fields.String(required=True, description="Banner content"),
|
||||
"link": fields.String(required=True, description="Banner link"),
|
||||
"sort": fields.Integer(required=True, description="Banner sort"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Banner inserted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=args["content"],
|
||||
link=args["link"],
|
||||
sort=args["sort"],
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBanner(Resource):
|
||||
@api.doc("delete_explore_banner")
|
||||
@api.doc(description="Delete an explore banner")
|
||||
@api.response(204, "Banner deleted successfully")
|
||||
@admin_required
|
||||
@only_edition_cloud
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
banners = (
|
||||
db.session.query(ExporleBanner).filter(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all()
|
||||
)
|
||||
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
|
|
@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
|||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ recommended_app_fields = {
|
|||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,349 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common import fields
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return site
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
|
@ -2,15 +2,16 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
|
|
@ -83,3 +137,13 @@ class InstalledAppResource(Resource):
|
|||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"""add table explore banner and trial
|
||||
|
||||
Revision ID: 1b435d90db42
|
||||
Revises: cf7c38a32b2d
|
||||
Create Date: 2025-09-19 14:42:58.416649
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1b435d90db42'
|
||||
down_revision = 'cf7c38a32b2d'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('account_trial_app_records',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('account_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('count', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
|
||||
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
|
||||
)
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
|
||||
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
|
||||
|
||||
op.create_table('exporle_banners',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('content', sa.JSON(), nullable=False),
|
||||
sa.Column('link', sa.String(length=255), nullable=False),
|
||||
sa.Column('sort', sa.Integer(), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
|
||||
)
|
||||
op.create_table('trial_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('trial_limit', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
|
||||
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
|
||||
)
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
|
||||
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_status')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
|
||||
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.drop_index('trial_app_tenant_id_idx')
|
||||
batch_op.drop_index('trial_app_app_id_idx')
|
||||
|
||||
op.drop_table('trial_apps')
|
||||
op.drop_table('exporle_banners')
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.drop_index('account_trial_app_record_app_id_idx')
|
||||
batch_op.drop_index('account_trial_app_record_account_id_idx')
|
||||
|
||||
op.drop_table('account_trial_app_records')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -28,6 +28,7 @@ from .dataset import (
|
|||
)
|
||||
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
|
||||
from .model import (
|
||||
AccountTrialAppRecord,
|
||||
ApiRequest,
|
||||
ApiToken,
|
||||
App,
|
||||
|
|
@ -40,6 +41,7 @@ from .model import (
|
|||
DatasetRetrieverResource,
|
||||
DifySetup,
|
||||
EndUser,
|
||||
ExporleBanner,
|
||||
IconType,
|
||||
InstalledApp,
|
||||
Message,
|
||||
|
|
@ -54,6 +56,7 @@ from .model import (
|
|||
Tag,
|
||||
TagBinding,
|
||||
TraceAppConfig,
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
|
|
@ -98,6 +101,7 @@ __all__ = [
|
|||
"Account",
|
||||
"AccountIntegrate",
|
||||
"AccountStatus",
|
||||
"AccountTrialAppRecord",
|
||||
"ApiRequest",
|
||||
"ApiToken",
|
||||
"ApiToolProvider",
|
||||
|
|
@ -131,6 +135,7 @@ __all__ = [
|
|||
"DocumentSegment",
|
||||
"Embedding",
|
||||
"EndUser",
|
||||
"ExporleBanner",
|
||||
"ExternalKnowledgeApis",
|
||||
"ExternalKnowledgeBindings",
|
||||
"IconType",
|
||||
|
|
@ -168,6 +173,7 @@ __all__ = [
|
|||
"ToolLabelBinding",
|
||||
"ToolModelInvoke",
|
||||
"TraceAppConfig",
|
||||
"TrialApp",
|
||||
"UploadFile",
|
||||
"UserFrom",
|
||||
"Whitelist",
|
||||
|
|
|
|||
|
|
@ -581,6 +581,63 @@ class InstalledApp(Base):
|
|||
return tenant
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
__tablename__ = "trial_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||
sa.Index("trial_app_app_id_idx", "app_id"),
|
||||
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
__tablename__ = "account_trial_app_records"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
|
||||
|
||||
class ExporleBanner(Base):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
content = mapped_column(sa.JSON, nullable=False)
|
||||
link = mapped_column(String(255), nullable=False)
|
||||
sort = mapped_column(sa.Integer, nullable=False)
|
||||
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
|
|
|
|||
|
|
@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel):
|
|||
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
|
||||
enable_change_email: bool = True
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
|
||||
|
||||
class FeatureService:
|
||||
|
|
@ -214,6 +216,8 @@ class FeatureService:
|
|||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_env(cls, features: FeatureModel):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services.feature_service import FeatureService
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
|
||||
|
||||
|
|
@ -20,6 +25,15 @@ class RecommendedAppService:
|
|||
)
|
||||
)
|
||||
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
apps = result["recommended_apps"]
|
||||
for app in apps:
|
||||
app_id = app["app_id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
app["can_trial"] = True
|
||||
else:
|
||||
app["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
|
@ -32,4 +46,27 @@ class RecommendedAppService:
|
|||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
app_id = result["id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
result["can_trial"] = True
|
||||
else:
|
||||
result["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def add_trial_app_record(cls, app_id: str, account_id: str):
|
||||
"""
|
||||
Add trial app record.
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
account_trial_app_record = session.query(AccountTrialAppRecord).where(TrialApp.app_id == app_id).first()
|
||||
if account_trial_app_record:
|
||||
account_trial_app_record.count += 1
|
||||
session.commit()
|
||||
else:
|
||||
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||
session.commit()
|
||||
|
|
|
|||
Loading…
Reference in New Issue