mirror of https://github.com/langgenius/dify.git
refactor: port reqparse to Pydantic model (#28949)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
6325dcf8aa
commit
7396eba1af
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
@ -18,6 +19,30 @@ from extensions.ext_database import db
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class InsertExploreAppPayload(BaseModel):
|
||||||
|
app_id: str = Field(...)
|
||||||
|
desc: str | None = None
|
||||||
|
copyright: str | None = None
|
||||||
|
privacy_policy: str | None = None
|
||||||
|
custom_disclaimer: str | None = None
|
||||||
|
language: str = Field(...)
|
||||||
|
category: str = Field(...)
|
||||||
|
position: int = Field(...)
|
||||||
|
|
||||||
|
@field_validator("language")
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, value: str) -> str:
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
InsertExploreAppPayload.__name__,
|
||||||
|
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def admin_required(view: Callable[P, R]):
|
def admin_required(view: Callable[P, R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
|
|
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
|
||||||
class InsertExploreAppListApi(Resource):
|
class InsertExploreAppListApi(Resource):
|
||||||
@console_ns.doc("insert_explore_app")
|
@console_ns.doc("insert_explore_app")
|
||||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"InsertExploreAppRequest",
|
|
||||||
{
|
|
||||||
"app_id": fields.String(required=True, description="Application ID"),
|
|
||||||
"desc": fields.String(description="App description"),
|
|
||||||
"copyright": fields.String(description="Copyright information"),
|
|
||||||
"privacy_policy": fields.String(description="Privacy policy"),
|
|
||||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
|
||||||
"language": fields.String(required=True, description="Language code"),
|
|
||||||
"category": fields.String(required=True, description="App category"),
|
|
||||||
"position": fields.Integer(required=True, description="Display position"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "App updated successfully")
|
@console_ns.response(200, "App updated successfully")
|
||||||
@console_ns.response(201, "App inserted successfully")
|
@console_ns.response(201, "App inserted successfully")
|
||||||
@console_ns.response(404, "App not found")
|
@console_ns.response(404, "App not found")
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("desc", type=str, location="json")
|
|
||||||
.add_argument("copyright", type=str, location="json")
|
|
||||||
.add_argument("privacy_policy", type=str, location="json")
|
|
||||||
.add_argument("custom_disclaimer", type=str, location="json")
|
|
||||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||||
if not app:
|
if not app:
|
||||||
raise NotFound(f"App '{args['app_id']}' is not found")
|
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||||
|
|
||||||
site = app.site
|
site = app.site
|
||||||
if not site:
|
if not site:
|
||||||
desc = args["desc"] or ""
|
desc = payload.desc or ""
|
||||||
copy_right = args["copyright"] or ""
|
copy_right = payload.copyright or ""
|
||||||
privacy_policy = args["privacy_policy"] or ""
|
privacy_policy = payload.privacy_policy or ""
|
||||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
custom_disclaimer = payload.custom_disclaimer or ""
|
||||||
else:
|
else:
|
||||||
desc = site.description or args["desc"] or ""
|
desc = site.description or payload.desc or ""
|
||||||
copy_right = site.copyright or args["copyright"] or ""
|
copy_right = site.copyright or payload.copyright or ""
|
||||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
recommended_app = session.execute(
|
recommended_app = session.execute(
|
||||||
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
|
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
|
|
||||||
if not recommended_app:
|
if not recommended_app:
|
||||||
|
|
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
|
||||||
copyright=copy_right,
|
copyright=copy_right,
|
||||||
privacy_policy=privacy_policy,
|
privacy_policy=privacy_policy,
|
||||||
custom_disclaimer=custom_disclaimer,
|
custom_disclaimer=custom_disclaimer,
|
||||||
language=args["language"],
|
language=payload.language,
|
||||||
category=args["category"],
|
category=payload.category,
|
||||||
position=args["position"],
|
position=payload.position,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(recommended_app)
|
db.session.add(recommended_app)
|
||||||
|
|
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
|
||||||
recommended_app.copyright = copy_right
|
recommended_app.copyright = copy_right
|
||||||
recommended_app.privacy_policy = privacy_policy
|
recommended_app.privacy_policy = privacy_policy
|
||||||
recommended_app.custom_disclaimer = custom_disclaimer
|
recommended_app.custom_disclaimer = custom_disclaimer
|
||||||
recommended_app.language = args["language"]
|
recommended_app.language = payload.language
|
||||||
recommended_app.category = args["category"]
|
recommended_app.category = payload.category
|
||||||
recommended_app.position = args["position"]
|
recommended_app.position = payload.position
|
||||||
|
|
||||||
app.is_public = True
|
app.is_public = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -8,10 +10,21 @@ from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.agent_service import AgentService
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
|
||||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
class AgentLogQuery(BaseModel):
|
||||||
|
message_id: str = Field(..., description="Message UUID")
|
||||||
|
conversation_id: str = Field(..., description="Conversation UUID")
|
||||||
|
|
||||||
|
@field_validator("message_id", "conversation_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str) -> str:
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
|
||||||
@console_ns.doc("get_agent_logs")
|
@console_ns.doc("get_agent_logs")
|
||||||
@console_ns.doc(description="Get agent execution logs for an application")
|
@console_ns.doc(description="Get agent execution logs for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||||
)
|
)
|
||||||
|
|
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
|
||||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get agent logs"""
|
"""Get agent logs"""
|
||||||
args = parser.parse_args()
|
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -21,22 +22,79 @@ from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReplyPayload(BaseModel):
|
||||||
|
score_threshold: float = Field(..., description="Score threshold for annotation matching")
|
||||||
|
embedding_provider_name: str = Field(..., description="Embedding provider name")
|
||||||
|
embedding_model_name: str = Field(..., description="Embedding model name")
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationSettingUpdatePayload(BaseModel):
|
||||||
|
score_threshold: float = Field(..., description="Score threshold")
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, description="Page number")
|
||||||
|
limit: int = Field(default=20, ge=1, description="Page size")
|
||||||
|
keyword: str = Field(default="", description="Search keyword")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAnnotationPayload(BaseModel):
|
||||||
|
message_id: str | None = Field(default=None, description="Message ID")
|
||||||
|
question: str | None = Field(default=None, description="Question text")
|
||||||
|
answer: str | None = Field(default=None, description="Answer text")
|
||||||
|
content: str | None = Field(default=None, description="Content text")
|
||||||
|
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAnnotationPayload(BaseModel):
|
||||||
|
question: str | None = None
|
||||||
|
answer: str | None = None
|
||||||
|
content: str | None = None
|
||||||
|
annotation_reply: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReplyStatusQuery(BaseModel):
|
||||||
|
action: Literal["enable", "disable"]
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationFilePayload(BaseModel):
|
||||||
|
message_id: str = Field(..., description="Message ID")
|
||||||
|
|
||||||
|
@field_validator("message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_message_id(cls, value: str) -> str:
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
def reg(model: type[BaseModel]) -> None:
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(AnnotationReplyPayload)
|
||||||
|
reg(AnnotationSettingUpdatePayload)
|
||||||
|
reg(AnnotationListQuery)
|
||||||
|
reg(CreateAnnotationPayload)
|
||||||
|
reg(UpdateAnnotationPayload)
|
||||||
|
reg(AnnotationReplyStatusQuery)
|
||||||
|
reg(AnnotationFilePayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||||
class AnnotationReplyActionApi(Resource):
|
class AnnotationReplyActionApi(Resource):
|
||||||
@console_ns.doc("annotation_reply_action")
|
@console_ns.doc("annotation_reply_action")
|
||||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AnnotationReplyActionRequest",
|
|
||||||
{
|
|
||||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
|
||||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
|
|
||||||
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Action completed successfully")
|
@console_ns.response(200, "Action completed successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = (
|
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
|
||||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
|
||||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
if action == "enable":
|
if action == "enable":
|
||||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
|
||||||
elif action == "disable":
|
elif action == "disable":
|
||||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||||
@console_ns.doc("update_annotation_setting")
|
@console_ns.doc("update_annotation_setting")
|
||||||
@console_ns.doc(description="Update annotation settings for an app")
|
@console_ns.doc(description="Update annotation settings for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AnnotationSettingUpdateRequest",
|
|
||||||
{
|
|
||||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
|
||||||
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
|
|
||||||
"embedding_model_name": fields.String(required=True, description="Embedding model"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Settings updated successfully")
|
@console_ns.response(200, "Settings updated successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_setting_id = str(annotation_setting_id)
|
annotation_setting_id = str(annotation_setting_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
|
||||||
return result, 200
|
return result, 200
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
|
||||||
@console_ns.doc("list_annotations")
|
@console_ns.doc("list_annotations")
|
||||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
|
||||||
console_ns.parser()
|
|
||||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
|
||||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
|
||||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Annotations retrieved successfully")
|
@console_ns.response(200, "Annotations retrieved successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
page = request.args.get("page", default=1, type=int)
|
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
page = args.page
|
||||||
keyword = request.args.get("keyword", default="", type=str)
|
limit = args.limit
|
||||||
|
keyword = args.keyword
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||||
|
|
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
|
||||||
@console_ns.doc("create_annotation")
|
@console_ns.doc("create_annotation")
|
||||||
@console_ns.doc(description="Create a new annotation for an app")
|
@console_ns.doc(description="Create a new annotation for an app")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateAnnotationRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(description="Message ID (optional)"),
|
|
||||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
|
||||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
|
||||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
|
||||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = (
|
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
data = args.model_dump(exclude_none=True)
|
||||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
|
||||||
.add_argument("question", required=False, type=str, location="json")
|
|
||||||
.add_argument("answer", required=False, type=str, location="json")
|
|
||||||
.add_argument("content", required=False, type=str, location="json")
|
|
||||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
|
||||||
return response, 200
|
return response, 200
|
||||||
|
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("question", required=True, type=str, location="json")
|
|
||||||
.add_argument("answer", required=True, type=str, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||||
class AnnotationUpdateDeleteApi(Resource):
|
class AnnotationUpdateDeleteApi(Resource):
|
||||||
@console_ns.doc("update_delete_annotation")
|
@console_ns.doc("update_delete_annotation")
|
||||||
|
|
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||||
@console_ns.response(204, "Annotation deleted successfully")
|
@console_ns.response(204, "Annotation deleted successfully")
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
args = parser.parse_args()
|
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||||
|
args.model_dump(exclude_none=True), app_id, annotation_id
|
||||||
|
)
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
|
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
|
||||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||||
)
|
)
|
||||||
|
|
||||||
parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("mode", type=str, required=True, location="json")
|
|
||||||
.add_argument("yaml_content", type=str, location="json")
|
class AppImportPayload(BaseModel):
|
||||||
.add_argument("yaml_url", type=str, location="json")
|
mode: str = Field(..., description="Import mode")
|
||||||
.add_argument("name", type=str, location="json")
|
yaml_content: str | None = None
|
||||||
.add_argument("description", type=str, location="json")
|
yaml_url: str | None = None
|
||||||
.add_argument("icon_type", type=str, location="json")
|
name: str | None = None
|
||||||
.add_argument("icon", type=str, location="json")
|
description: str | None = None
|
||||||
.add_argument("icon_background", type=str, location="json")
|
icon_type: str | None = None
|
||||||
.add_argument("app_id", type=str, location="json")
|
icon: str | None = None
|
||||||
|
icon_background: str | None = None
|
||||||
|
app_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/imports")
|
@console_ns.route("/apps/imports")
|
||||||
class AppImportApi(Resource):
|
class AppImportApi(Resource):
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -61,7 +68,7 @@ class AppImportApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
args = parser.parse_args()
|
args = AppImportPayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
|
|
@ -70,15 +77,15 @@ class AppImportApi(Resource):
|
||||||
account = current_user
|
account = current_user
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args.mode,
|
||||||
yaml_content=args.get("yaml_content"),
|
yaml_content=args.yaml_content,
|
||||||
yaml_url=args.get("yaml_url"),
|
yaml_url=args.yaml_url,
|
||||||
name=args.get("name"),
|
name=args.name,
|
||||||
description=args.get("description"),
|
description=args.description,
|
||||||
icon_type=args.get("icon_type"),
|
icon_type=args.icon_type,
|
||||||
icon=args.get("icon"),
|
icon=args.icon,
|
||||||
icon_background=args.get("icon_background"),
|
icon_background=args.icon_background,
|
||||||
app_id=args.get("app_id"),
|
app_id=args.app_id,
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|
@ -32,6 +33,27 @@ from services.errors.audio import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechPayload(BaseModel):
|
||||||
|
message_id: str | None = Field(default=None, description="Message ID")
|
||||||
|
text: str = Field(..., description="Text to convert")
|
||||||
|
voice: str | None = Field(default=None, description="Voice name")
|
||||||
|
streaming: bool | None = Field(default=None, description="Whether to stream audio")
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechVoiceQuery(BaseModel):
|
||||||
|
language: str = Field(..., description="Language code")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
TextToSpeechVoiceQuery.__name__,
|
||||||
|
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||||
|
|
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
|
||||||
@console_ns.doc("chat_message_text_to_speech")
|
@console_ns.doc("chat_message_text_to_speech")
|
||||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||||
@console_ns.doc(params={"app_id": "App ID"})
|
@console_ns.doc(params={"app_id": "App ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TextToSpeechRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(description="Message ID"),
|
|
||||||
"text": fields.String(required=True, description="Text to convert to speech"),
|
|
||||||
"voice": fields.String(description="Voice to use for TTS"),
|
|
||||||
"streaming": fields.Boolean(description="Whether to stream the audio"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Text to speech conversion successful")
|
@console_ns.response(200, "Text to speech conversion successful")
|
||||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
|
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = (
|
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("message_id", type=str, location="json")
|
|
||||||
.add_argument("text", type=str, location="json")
|
|
||||||
.add_argument("voice", type=str, location="json")
|
|
||||||
.add_argument("streaming", type=bool, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
|
||||||
text = args.get("text", None)
|
|
||||||
voice = args.get("voice", None)
|
|
||||||
|
|
||||||
response = AudioService.transcript_tts(
|
response = AudioService.transcript_tts(
|
||||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
app_model=app_model,
|
||||||
|
text=payload.text,
|
||||||
|
voice=payload.voice,
|
||||||
|
message_id=payload.message_id,
|
||||||
|
is_draft=True,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
|
@ -159,9 +164,7 @@ class TextModesApi(Resource):
|
||||||
@console_ns.doc("get_text_to_speech_voices")
|
@console_ns.doc("get_text_to_speech_voices")
|
||||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||||
@console_ns.doc(params={"app_id": "App ID"})
|
@console_ns.doc(params={"app_id": "App ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
|
||||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||||
)
|
)
|
||||||
|
|
@ -172,12 +175,11 @@ class TextModesApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
response = AudioService.transcript_tts_voices(
|
response = AudioService.transcript_tts_voices(
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
language=args["language"],
|
language=args.language,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import json
|
import json
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import AppMCPServer
|
from models.model import AppMCPServer
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||||
|
|
||||||
|
|
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
|
||||||
INACTIVE = "inactive"
|
INACTIVE = "inactive"
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerCreatePayload(BaseModel):
|
||||||
|
description: str | None = Field(default=None, description="Server description")
|
||||||
|
parameters: dict = Field(..., description="Server parameters configuration")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerUpdatePayload(BaseModel):
|
||||||
|
id: str = Field(..., description="Server ID")
|
||||||
|
description: str | None = Field(default=None, description="Server description")
|
||||||
|
parameters: dict = Field(..., description="Server parameters configuration")
|
||||||
|
status: str | None = Field(default=None, description="Server status")
|
||||||
|
|
||||||
|
|
||||||
|
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||||
class AppMCPServerController(Resource):
|
class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("get_app_mcp_server")
|
@console_ns.doc("get_app_mcp_server")
|
||||||
|
|
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("create_app_mcp_server")
|
@console_ns.doc("create_app_mcp_server")
|
||||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"MCPServerCreateRequest",
|
|
||||||
{
|
|
||||||
"description": fields.String(description="Server description"),
|
|
||||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("parameters", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
description = args.get("description")
|
description = payload.description
|
||||||
if not description:
|
if not description:
|
||||||
description = app_model.description or ""
|
description = app_model.description or ""
|
||||||
|
|
||||||
server = AppMCPServer(
|
server = AppMCPServer(
|
||||||
name=app_model.name,
|
name=app_model.name,
|
||||||
description=description,
|
description=description,
|
||||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
parameters=json.dumps(payload.parameters, ensure_ascii=False),
|
||||||
status=AppMCPServerStatus.ACTIVE,
|
status=AppMCPServerStatus.ACTIVE,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
|
|
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
|
||||||
@console_ns.doc("update_app_mcp_server")
|
@console_ns.doc("update_app_mcp_server")
|
||||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"MCPServerUpdateRequest",
|
|
||||||
{
|
|
||||||
"id": fields.String(required=True, description="Server ID"),
|
|
||||||
"description": fields.String(description="Server description"),
|
|
||||||
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
|
|
||||||
"status": fields.String(description="Server status"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(404, "Server not found")
|
@console_ns.response(404, "Server not found")
|
||||||
|
|
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
|
||||||
@marshal_with(app_server_model)
|
@marshal_with(app_server_model)
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
parser = (
|
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
|
||||||
.add_argument("id", type=str, required=True, location="json")
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("parameters", type=dict, required=True, location="json")
|
|
||||||
.add_argument("status", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
|
||||||
if not server:
|
if not server:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
description = args.get("description")
|
description = payload.description
|
||||||
if description is None:
|
if description is None:
|
||||||
pass
|
pass
|
||||||
elif not description:
|
elif not description:
|
||||||
|
|
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
|
||||||
else:
|
else:
|
||||||
server.description = description
|
server.description = description
|
||||||
|
|
||||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
|
||||||
if args["status"]:
|
if payload.status:
|
||||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
if payload.status not in [status.value for status in AppMCPServerStatus]:
|
||||||
raise ValueError("Invalid status")
|
raise ValueError("Invalid status")
|
||||||
server.status = args["status"]
|
server.status = payload.status
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
from flask_restx import Resource, fields, reqparse
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.ops_service import OpsService
|
from services.ops_service import OpsService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class TraceProviderQuery(BaseModel):
|
||||||
|
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||||
|
|
||||||
|
|
||||||
|
class TraceConfigPayload(BaseModel):
|
||||||
|
tracing_provider: str = Field(..., description="Tracing provider name")
|
||||||
|
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TraceProviderQuery.__name__,
|
||||||
|
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
@console_ns.route("/apps/<uuid:app_id>/trace-config")
|
||||||
class TraceAppConfigApi(Resource):
|
class TraceAppConfigApi(Resource):
|
||||||
|
|
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("get_trace_app_config")
|
@console_ns.doc("get_trace_app_config")
|
||||||
@console_ns.doc(description="Get tracing configuration for an application")
|
@console_ns.doc(description="Get tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||||
)
|
)
|
||||||
|
|
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||||
if not trace_config:
|
if not trace_config:
|
||||||
return {"has_not_configured": True}
|
return {"has_not_configured": True}
|
||||||
return trace_config
|
return trace_config
|
||||||
|
|
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("create_trace_app_config")
|
@console_ns.doc("create_trace_app_config")
|
||||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TraceConfigCreateRequest",
|
|
||||||
{
|
|
||||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
|
||||||
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||||
)
|
)
|
||||||
|
|
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
"""Create a new trace app configuration"""
|
"""Create a new trace app configuration"""
|
||||||
parser = (
|
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
|
||||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.create_tracing_app_config(
|
result = OpsService.create_tracing_app_config(
|
||||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigIsExist()
|
raise TracingConfigIsExist()
|
||||||
|
|
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("update_trace_app_config")
|
@console_ns.doc("update_trace_app_config")
|
||||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"TraceConfigUpdateRequest",
|
|
||||||
{
|
|
||||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
|
||||||
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, app_id):
|
def patch(self, app_id):
|
||||||
"""Update an existing trace app configuration"""
|
"""Update an existing trace app configuration"""
|
||||||
parser = (
|
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
|
||||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.update_tracing_app_config(
|
result = OpsService.update_tracing_app_config(
|
||||||
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
|
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
|
|
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
|
||||||
@console_ns.doc("delete_trace_app_config")
|
@console_ns.doc("delete_trace_app_config")
|
||||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
"""Delete an existing trace app configuration"""
|
"""Delete an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
|
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||||
if not result:
|
if not result:
|
||||||
raise TracingConfigNotExist()
|
raise TracingConfigNotExist()
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from typing import Literal
|
||||||
|
|
||||||
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
|
|
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Site
|
from models import Site
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class AppSiteUpdatePayload(BaseModel):
|
||||||
|
title: str | None = Field(default=None)
|
||||||
|
icon_type: str | None = Field(default=None)
|
||||||
|
icon: str | None = Field(default=None)
|
||||||
|
icon_background: str | None = Field(default=None)
|
||||||
|
description: str | None = Field(default=None)
|
||||||
|
default_language: str | None = Field(default=None)
|
||||||
|
chat_color_theme: str | None = Field(default=None)
|
||||||
|
chat_color_theme_inverted: bool | None = Field(default=None)
|
||||||
|
customize_domain: str | None = Field(default=None)
|
||||||
|
copyright: str | None = Field(default=None)
|
||||||
|
privacy_policy: str | None = Field(default=None)
|
||||||
|
custom_disclaimer: str | None = Field(default=None)
|
||||||
|
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
|
||||||
|
prompt_public: bool | None = Field(default=None)
|
||||||
|
show_workflow_steps: bool | None = Field(default=None)
|
||||||
|
use_icon_as_answer_icon: bool | None = Field(default=None)
|
||||||
|
|
||||||
|
@field_validator("default_language")
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
AppSiteUpdatePayload.__name__,
|
||||||
|
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("title", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon_type", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon", type=str, required=False, location="json")
|
|
||||||
.add_argument("icon_background", type=str, required=False, location="json")
|
|
||||||
.add_argument("description", type=str, required=False, location="json")
|
|
||||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
|
||||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
|
||||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
|
||||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
|
||||||
.add_argument("copyright", type=str, required=False, location="json")
|
|
||||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
|
||||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"customize_token_strategy",
|
|
||||||
type=str,
|
|
||||||
choices=["must", "allow", "not_allow"],
|
|
||||||
required=False,
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
|
||||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
|
||||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||||
class AppSite(Resource):
|
class AppSite(Resource):
|
||||||
@console_ns.doc("update_app_site")
|
@console_ns.doc("update_app_site")
|
||||||
@console_ns.doc(description="Update application site configuration")
|
@console_ns.doc(description="Update application site configuration")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"AppSiteRequest",
|
|
||||||
{
|
|
||||||
"title": fields.String(description="Site title"),
|
|
||||||
"icon_type": fields.String(description="Icon type"),
|
|
||||||
"icon": fields.String(description="Icon"),
|
|
||||||
"icon_background": fields.String(description="Icon background color"),
|
|
||||||
"description": fields.String(description="Site description"),
|
|
||||||
"default_language": fields.String(description="Default language"),
|
|
||||||
"chat_color_theme": fields.String(description="Chat color theme"),
|
|
||||||
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
|
|
||||||
"customize_domain": fields.String(description="Custom domain"),
|
|
||||||
"copyright": fields.String(description="Copyright text"),
|
|
||||||
"privacy_policy": fields.String(description="Privacy policy"),
|
|
||||||
"custom_disclaimer": fields.String(description="Custom disclaimer"),
|
|
||||||
"customize_token_strategy": fields.String(
|
|
||||||
enum=["must", "allow", "not_allow"], description="Token strategy"
|
|
||||||
),
|
|
||||||
"prompt_public": fields.Boolean(description="Make prompt public"),
|
|
||||||
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
|
|
||||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||||
@console_ns.response(403, "Insufficient permissions")
|
@console_ns.response(403, "Insufficient permissions")
|
||||||
@console_ns.response(404, "App not found")
|
@console_ns.response(404, "App not found")
|
||||||
|
|
@ -89,7 +73,7 @@ class AppSite(Resource):
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_site_model)
|
@marshal_with(app_site_model)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parse_app_site_args()
|
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
if not site:
|
if not site:
|
||||||
|
|
@ -113,7 +97,7 @@ class AppSite(Resource):
|
||||||
"show_workflow_steps",
|
"show_workflow_steps",
|
||||||
"use_icon_as_answer_icon",
|
"use_icon_as_answer_icon",
|
||||||
]:
|
]:
|
||||||
value = args.get(attr_name)
|
value = getattr(args, attr_name)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import NoReturn, ParamSpec, TypeVar
|
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response, request
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowDraftVariableListQuery(BaseModel):
|
||||||
|
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||||
|
name: str | None = Field(default=None, description="Variable name")
|
||||||
|
value: Any | None = Field(default=None, description="Variable value")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowDraftVariableListQuery.__name__,
|
||||||
|
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
console_ns.schema_model(
|
||||||
|
WorkflowDraftVariableUpdatePayload.__name__,
|
||||||
|
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_values_to_json_serializable_object(value: Segment):
|
def _convert_values_to_json_serializable_object(value: Segment):
|
||||||
|
|
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||||
return _convert_values_to_json_serializable_object(value)
|
return _convert_values_to_json_serializable_object(value)
|
||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"page",
|
|
||||||
type=inputs.int_range(1, 100_000),
|
|
||||||
required=False,
|
|
||||||
default=1,
|
|
||||||
location="args",
|
|
||||||
help="the page of data requested",
|
|
||||||
)
|
|
||||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||||
value_type = workflow_draft_var.value_type
|
value_type = workflow_draft_var.value_type
|
||||||
return value_type.exposed_type().value
|
return value_type.exposed_type().value
|
||||||
|
|
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||||
class WorkflowVariableCollectionApi(Resource):
|
class WorkflowVariableCollectionApi(Resource):
|
||||||
@console_ns.expect(_create_pagination_parser())
|
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||||
@console_ns.doc("get_workflow_variables")
|
@console_ns.doc("get_workflow_variables")
|
||||||
@console_ns.doc(description="Get draft workflow variables")
|
@console_ns.doc(description="Get draft workflow variables")
|
||||||
@console_ns.doc(params={"app_id": "Application ID"})
|
@console_ns.doc(params={"app_id": "Application ID"})
|
||||||
|
|
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||||
"""
|
"""
|
||||||
Get draft workflow
|
Get draft workflow
|
||||||
"""
|
"""
|
||||||
parser = _create_pagination_parser()
|
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
@ -323,15 +328,7 @@ class VariableApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("update_variable")
|
@console_ns.doc("update_variable")
|
||||||
@console_ns.doc(description="Update a workflow variable")
|
@console_ns.doc(description="Update a workflow variable")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateVariableRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(description="Variable name"),
|
|
||||||
"value": fields.Raw(description="Variable value"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||||
@console_ns.response(404, "Variable not found")
|
@console_ns.response(404, "Variable not found")
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
|
|
@ -358,16 +355,10 @@ class VariableApi(Resource):
|
||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
)
|
)
|
||||||
args = parser.parse_args(strict=True)
|
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
|
|
@ -375,8 +366,8 @@ class VariableApi(Resource):
|
||||||
if variable.app_id != app_model.id:
|
if variable.app_id != app_model.id:
|
||||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
new_name = args_model.name
|
||||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
raw_value = args_model.value
|
||||||
if new_name is None and raw_value is None:
|
if new_name is None and raw_value is None:
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,53 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||||
from models import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import AccountService, RegisterService
|
||||||
|
|
||||||
active_check_parser = (
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
|
||||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
class ActivateCheckQuery(BaseModel):
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
workspace_id: str | None = Field(default=None)
|
||||||
)
|
email: EmailStr | None = Field(default=None)
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class ActivatePayload(BaseModel):
|
||||||
|
workspace_id: str | None = Field(default=None)
|
||||||
|
email: EmailStr | None = Field(default=None)
|
||||||
|
token: str
|
||||||
|
name: str = Field(..., max_length=30)
|
||||||
|
interface_language: str = Field(...)
|
||||||
|
timezone: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("interface_language")
|
||||||
|
@classmethod
|
||||||
|
def validate_lang(cls, value: str) -> str:
|
||||||
|
return supported_language(value)
|
||||||
|
|
||||||
|
@field_validator("timezone")
|
||||||
|
@classmethod
|
||||||
|
def validate_tz(cls, value: str) -> str:
|
||||||
|
return timezone(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (ActivateCheckQuery, ActivatePayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate/check")
|
@console_ns.route("/activate/check")
|
||||||
class ActivateCheckApi(Resource):
|
class ActivateCheckApi(Resource):
|
||||||
@console_ns.doc("check_activation_token")
|
@console_ns.doc("check_activation_token")
|
||||||
@console_ns.doc(description="Check if activation token is valid")
|
@console_ns.doc(description="Check if activation token is valid")
|
||||||
@console_ns.expect(active_check_parser)
|
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
args = active_check_parser.parse_args()
|
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
|
||||||
workspaceId = args["workspace_id"]
|
workspaceId = args.workspace_id
|
||||||
reg_email = args["email"]
|
reg_email = args.email
|
||||||
token = args["token"]
|
token = args.token
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||||
if invitation:
|
if invitation:
|
||||||
|
|
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
|
||||||
return {"is_valid": False}
|
return {"is_valid": False}
|
||||||
|
|
||||||
|
|
||||||
active_parser = (
|
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
|
||||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate")
|
@console_ns.route("/activate")
|
||||||
class ActivateApi(Resource):
|
class ActivateApi(Resource):
|
||||||
@console_ns.doc("activate_account")
|
@console_ns.doc("activate_account")
|
||||||
@console_ns.doc(description="Activate account with invitation token")
|
@console_ns.doc(description="Activate account with invitation token")
|
||||||
@console_ns.expect(active_parser)
|
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Account activated successfully",
|
"Account activated successfully",
|
||||||
|
|
@ -85,19 +99,19 @@ class ActivateApi(Resource):
|
||||||
)
|
)
|
||||||
@console_ns.response(400, "Already activated or invalid token")
|
@console_ns.response(400, "Already activated or invalid token")
|
||||||
def post(self):
|
def post(self):
|
||||||
args = active_parser.parse_args()
|
args = ActivatePayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
|
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||||
if invitation is None:
|
if invitation is None:
|
||||||
raise AlreadyActivateError()
|
raise AlreadyActivateError()
|
||||||
|
|
||||||
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
|
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||||
|
|
||||||
account = invitation["account"]
|
account = invitation["account"]
|
||||||
account.name = args["name"]
|
account.name = args.name
|
||||||
|
|
||||||
account.interface_language = args["interface_language"]
|
account.interface_language = args.interface_language
|
||||||
account.timezone = args["timezone"]
|
account.timezone = args.timezone
|
||||||
account.interface_theme = "light"
|
account.interface_theme = "light"
|
||||||
account.status = AccountStatus.ACTIVE
|
account.status = AccountStatus.ACTIVE
|
||||||
account.initialized_at = naive_utc_now()
|
account.initialized_at = naive_utc_now()
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,26 @@
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
|
||||||
from controllers.console.wraps import is_admin_or_owner_required
|
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from .. import console_ns
|
||||||
|
from ..auth.error import ApiKeyAuthFailedError
|
||||||
|
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyAuthBindingPayload(BaseModel):
|
||||||
|
category: str = Field(...)
|
||||||
|
provider: str = Field(...)
|
||||||
|
credentials: dict = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
ApiKeyAuthBindingPayload.__name__,
|
||||||
|
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-key-auth/data-source")
|
@console_ns.route("/api-key-auth/data-source")
|
||||||
|
|
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@is_admin_or_owner_required
|
@is_admin_or_owner_required
|
||||||
|
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
data = payload.model_dump()
|
||||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,11 @@ from flask import current_app, redirect, request
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.wraps import is_admin_or_owner_required
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from libs.oauth_data_source import NotionOAuth
|
from libs.oauth_data_source import NotionOAuth
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from .. import console_ns
|
||||||
|
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
|
||||||
InvalidTokenError,
|
InvalidTokenError,
|
||||||
PasswordMismatchError,
|
PasswordMismatchError,
|
||||||
)
|
)
|
||||||
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
|
|
||||||
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
|
|
||||||
|
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||||
|
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterSendPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(..., description="Email address")
|
||||||
|
language: str | None = Field(default=None, description="Language code")
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterValidityPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRegisterResetPayload(BaseModel):
|
||||||
|
token: str = Field(...)
|
||||||
|
new_password: str = Field(...)
|
||||||
|
password_confirm: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("new_password", "password_confirm")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/email-register/send-email")
|
@console_ns.route("/email-register/send-email")
|
||||||
class EmailRegisterSendEmailApi(Resource):
|
class EmailRegisterSendEmailApi(Resource):
|
||||||
|
|
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
if args["language"] in languages:
|
if args.language in languages:
|
||||||
language = args["language"]
|
language = args.language
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||||
token = None
|
token = None
|
||||||
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
|
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
|
|
||||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
|
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||||
if is_email_register_error_rate_limit:
|
if is_email_register_error_rate_limit:
|
||||||
raise EmailRegisterLimitError()
|
raise EmailRegisterLimitError()
|
||||||
|
|
||||||
token_data = AccountService.get_email_register_data(args["token"])
|
token_data = AccountService.get_email_register_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
if user_email != token_data.get("email"):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args["code"] != token_data.get("code"):
|
if args.code != token_data.get("code"):
|
||||||
AccountService.add_email_register_error_rate_limit(args["email"])
|
AccountService.add_email_register_error_rate_limit(args.email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
AccountService.revoke_email_register_token(args["token"])
|
AccountService.revoke_email_register_token(args.token)
|
||||||
|
|
||||||
# Refresh token data by generating a new token
|
# Refresh token data by generating a new token
|
||||||
_, new_token = AccountService.generate_email_register_token(
|
_, new_token = AccountService.generate_email_register_token(
|
||||||
user_email, code=args["code"], additional_data={"phase": "register"}
|
user_email, code=args.code, additional_data={"phase": "register"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_email_register_error_rate_limit(args["email"])
|
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
if args["new_password"] != args["password_confirm"]:
|
if args.new_password != args.password_confirm:
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
# Validate token and get register data
|
# Validate token and get register data
|
||||||
register_data = AccountService.get_email_register_data(args["token"])
|
register_data = AccountService.get_email_register_data(args.token)
|
||||||
if not register_data:
|
if not register_data:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
# Must use token in reset phase
|
# Must use token in reset phase
|
||||||
|
|
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
# Revoke token to prevent reuse
|
# Revoke token to prevent reuse
|
||||||
AccountService.revoke_email_register_token(args["token"])
|
AccountService.revoke_email_register_token(args.token)
|
||||||
|
|
||||||
email = register_data.get("email", "")
|
email = register_data.get("email", "")
|
||||||
|
|
||||||
|
|
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
|
||||||
if account:
|
if account:
|
||||||
raise EmailAlreadyInUseError()
|
raise EmailAlreadyInUseError()
|
||||||
else:
|
else:
|
||||||
account = self._create_new_account(email, args["password_confirm"])
|
account = self._create_new_account(email, args.password_confirm)
|
||||||
if not account:
|
if not account:
|
||||||
raise AccountNotFoundError()
|
raise AccountNotFoundError()
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ import base64
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordSendPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordCheckPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordResetPayload(BaseModel):
|
||||||
|
token: str = Field(...)
|
||||||
|
new_password: str = Field(...)
|
||||||
|
password_confirm: str = Field(...)
|
||||||
|
|
||||||
|
@field_validator("new_password", "password_confirm")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password")
|
@console_ns.route("/forgot-password")
|
||||||
class ForgotPasswordSendEmailApi(Resource):
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
@console_ns.doc("send_forgot_password_email")
|
@console_ns.doc("send_forgot_password_email")
|
||||||
@console_ns.doc(description="Send password reset email")
|
@console_ns.doc(description="Send password reset email")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordEmailRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Email sent successfully",
|
"Email sent successfully",
|
||||||
|
|
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
account=account,
|
account=account,
|
||||||
email=args["email"],
|
email=args.email,
|
||||||
language=language,
|
language=language,
|
||||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||||
)
|
)
|
||||||
|
|
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
class ForgotPasswordCheckApi(Resource):
|
class ForgotPasswordCheckApi(Resource):
|
||||||
@console_ns.doc("check_forgot_password_code")
|
@console_ns.doc("check_forgot_password_code")
|
||||||
@console_ns.doc(description="Verify password reset code")
|
@console_ns.doc(description="Verify password reset code")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordCheckRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Email address"),
|
|
||||||
"code": fields.String(required=True, description="Verification code"),
|
|
||||||
"token": fields.String(required=True, description="Reset token"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Code verified successfully",
|
"Code verified successfully",
|
||||||
|
|
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
|
|
||||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||||
if is_forgot_password_error_rate_limit:
|
if is_forgot_password_error_rate_limit:
|
||||||
raise EmailPasswordResetLimitError()
|
raise EmailPasswordResetLimitError()
|
||||||
|
|
||||||
token_data = AccountService.get_reset_password_data(args["token"])
|
token_data = AccountService.get_reset_password_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
if user_email != token_data.get("email"):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args["code"] != token_data.get("code"):
|
if args.code != token_data.get("code"):
|
||||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(args.token)
|
||||||
|
|
||||||
# Refresh token data by generating a new token
|
# Refresh token data by generating a new token
|
||||||
_, new_token = AccountService.generate_reset_password_token(
|
_, new_token = AccountService.generate_reset_password_token(
|
||||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
|
||||||
class ForgotPasswordResetApi(Resource):
|
class ForgotPasswordResetApi(Resource):
|
||||||
@console_ns.doc("reset_password")
|
@console_ns.doc("reset_password")
|
||||||
@console_ns.doc(description="Reset password with verification token")
|
@console_ns.doc(description="Reset password with verification token")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"ForgotPasswordResetRequest",
|
|
||||||
{
|
|
||||||
"token": fields.String(required=True, description="Verification token"),
|
|
||||||
"new_password": fields.String(required=True, description="New password"),
|
|
||||||
"password_confirm": fields.String(required=True, description="Password confirmation"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Password reset successfully",
|
"Password reset successfully",
|
||||||
|
|
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
if args["new_password"] != args["password_confirm"]:
|
if args.new_password != args.password_confirm:
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
# Validate token and get reset data
|
# Validate token and get reset data
|
||||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
reset_data = AccountService.get_reset_password_data(args.token)
|
||||||
if not reset_data:
|
if not reset_data:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
# Must use token in reset phase
|
# Must use token in reset phase
|
||||||
|
|
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
# Revoke token to prevent reuse
|
# Revoke token to prevent reuse
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(args.token)
|
||||||
|
|
||||||
# Generate secure salt and hash password
|
# Generate secure salt and hash password
|
||||||
salt = secrets.token_bytes(16)
|
salt = secrets.token_bytes(16)
|
||||||
password_hashed = hash_password(args["new_password"], salt)
|
password_hashed = hash_password(args.new_password, salt)
|
||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
@ -23,7 +24,7 @@ from controllers.console.error import (
|
||||||
)
|
)
|
||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
from libs.token import (
|
from libs.token import (
|
||||||
clear_access_token_from_cookie,
|
clear_access_token_from_cookie,
|
||||||
|
|
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
|
||||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class LoginPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(..., description="Email address")
|
||||||
|
password: str = Field(..., description="Password")
|
||||||
|
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||||
|
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||||
|
|
||||||
|
|
||||||
|
class EmailPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(...)
|
||||||
|
code: str = Field(...)
|
||||||
|
token: str = Field(...)
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def reg(cls: type[BaseModel]):
|
||||||
|
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
|
reg(LoginPayload)
|
||||||
|
reg(EmailPayload)
|
||||||
|
reg(EmailCodeLoginPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/login")
|
@console_ns.route("/login")
|
||||||
class LoginApi(Resource):
|
class LoginApi(Resource):
|
||||||
|
|
@ -47,41 +78,36 @@ class LoginApi(Resource):
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Authenticate user and login."""
|
"""Authenticate user and login."""
|
||||||
parser = (
|
args = LoginPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("password", type=str, required=True, location="json")
|
|
||||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
|
||||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||||
if is_login_error_rate_limit:
|
if is_login_error_rate_limit:
|
||||||
raise EmailPasswordLoginLimitError()
|
raise EmailPasswordLoginLimitError()
|
||||||
|
|
||||||
invitation = args["invite_token"]
|
# TODO: why invitation is re-assigned with different type?
|
||||||
|
invitation = args.invite_token # type: ignore
|
||||||
if invitation:
|
if invitation:
|
||||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if invitation:
|
if invitation:
|
||||||
data = invitation.get("data", {})
|
data = invitation.get("data", {}) # type: ignore
|
||||||
invitee_email = data.get("email") if data else None
|
invitee_email = data.get("email") if data else None
|
||||||
if invitee_email != args["email"]:
|
if invitee_email != args.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||||
else:
|
else:
|
||||||
account = AccountService.authenticate(args["email"], args["password"])
|
account = AccountService.authenticate(args.email, args.password)
|
||||||
except services.errors.account.AccountLoginError:
|
except services.errors.account.AccountLoginError:
|
||||||
raise AccountBannedError()
|
raise AccountBannedError()
|
||||||
except services.errors.account.AccountPasswordError:
|
except services.errors.account.AccountPasswordError:
|
||||||
AccountService.add_login_error_rate_limit(args["email"])
|
AccountService.add_login_error_rate_limit(args.email)
|
||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
# SELF_HOSTED only have one workspace
|
# SELF_HOSTED only have one workspace
|
||||||
tenants = TenantService.get_join_tenants(account)
|
tenants = TenantService.get_join_tenants(account)
|
||||||
|
|
@ -97,7 +123,7 @@ class LoginApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args.email)
|
||||||
|
|
||||||
# Create response with cookies instead of returning tokens in body
|
# Create response with cookies instead of returning tokens in body
|
||||||
response = make_response({"result": "success"})
|
response = make_response({"result": "success"})
|
||||||
|
|
@ -134,25 +160,21 @@ class LogoutApi(Resource):
|
||||||
class ResetPasswordSendEmailApi(Resource):
|
class ResetPasswordSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args.email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
email=args["email"],
|
email=args.email,
|
||||||
account=account,
|
account=account,
|
||||||
language=language,
|
language=language,
|
||||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||||
|
|
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
|
||||||
@console_ns.route("/email-code-login")
|
@console_ns.route("/email-code-login")
|
||||||
class EmailCodeLoginSendEmailApi(Resource):
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args.language is not None and args.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(args["email"])
|
account = AccountService.get_user_through_email(args.email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
if account is None:
|
if account is None:
|
||||||
if FeatureService.get_system_features().is_allow_register:
|
if FeatureService.get_system_features().is_allow_register:
|
||||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||||
else:
|
else:
|
||||||
raise AccountNotFound()
|
raise AccountNotFound()
|
||||||
else:
|
else:
|
||||||
|
|
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@console_ns.route("/email-code-login/validity")
|
@console_ns.route("/email-code-login/validity")
|
||||||
class EmailCodeLoginApi(Resource):
|
class EmailCodeLoginApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args.email
|
||||||
language = args["language"]
|
language = args.language
|
||||||
|
|
||||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
token_data = AccountService.get_email_code_login_data(args.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if token_data["email"] != args["email"]:
|
if token_data["email"] != args.email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if token_data["code"] != args["code"]:
|
if token_data["code"] != args.code:
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
AccountService.revoke_email_code_login_token(args["token"])
|
AccountService.revoke_email_code_login_token(args.token)
|
||||||
try:
|
try:
|
||||||
account = AccountService.get_user_through_email(user_email)
|
account = AccountService.get_user_through_email(user_email)
|
||||||
except AccountRegisterError:
|
except AccountRegisterError:
|
||||||
|
|
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
|
||||||
except WorkspacesLimitExceededError:
|
except WorkspacesLimitExceededError:
|
||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args.email)
|
||||||
|
|
||||||
# Create response with cookies instead of returning tokens in body
|
# Create response with cookies instead of returning tokens in body
|
||||||
response = make_response({"result": "success"})
|
response = make_response({"result": "success"})
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
|
@ -20,15 +21,34 @@ R = TypeVar("R")
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientPayload(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTokenRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
grant_type: str
|
||||||
|
code: str | None = None
|
||||||
|
client_secret: str | None = None
|
||||||
|
redirect_uri: str | None = None
|
||||||
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
json_data = request.get_json()
|
||||||
parsed_args = parser.parse_args()
|
if json_data is None:
|
||||||
client_id = parsed_args.get("client_id")
|
|
||||||
if not client_id:
|
|
||||||
raise BadRequest("client_id is required")
|
raise BadRequest("client_id is required")
|
||||||
|
|
||||||
|
payload = OAuthClientPayload.model_validate(json_data)
|
||||||
|
client_id = payload.client_id
|
||||||
|
|
||||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||||
if not oauth_provider_app:
|
if not oauth_provider_app:
|
||||||
raise NotFound("client_id is invalid")
|
raise NotFound("client_id is invalid")
|
||||||
|
|
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
payload = OAuthProviderRequest.model_validate(request.get_json())
|
||||||
parsed_args = parser.parse_args()
|
redirect_uri = payload.redirect_uri
|
||||||
redirect_uri = parsed_args.get("redirect_uri")
|
|
||||||
|
|
||||||
# check if redirect_uri is valid
|
# check if redirect_uri is valid
|
||||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||||
|
|
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = (
|
payload = OAuthTokenRequest.model_validate(request.get_json())
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("grant_type", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=False, location="json")
|
|
||||||
.add_argument("client_secret", type=str, required=False, location="json")
|
|
||||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
|
||||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
parsed_args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
grant_type = OAuthGrantType(payload.grant_type)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise BadRequest("invalid grant_type")
|
raise BadRequest("invalid grant_type")
|
||||||
|
|
||||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||||
if not parsed_args["code"]:
|
if not payload.code:
|
||||||
raise BadRequest("code is required")
|
raise BadRequest("code is required")
|
||||||
|
|
||||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
if payload.client_secret != oauth_provider_app.client_secret:
|
||||||
raise BadRequest("client_secret is invalid")
|
raise BadRequest("client_secret is invalid")
|
||||||
|
|
||||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
|
||||||
raise BadRequest("redirect_uri is invalid")
|
raise BadRequest("redirect_uri is invalid")
|
||||||
|
|
||||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
|
||||||
)
|
)
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
|
|
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||||
if not parsed_args["refresh_token"]:
|
if not payload.refresh_token:
|
||||||
raise BadRequest("refresh_token is required")
|
raise BadRequest("refresh_token is required")
|
||||||
|
|
||||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
|
||||||
)
|
)
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionQuery(BaseModel):
|
||||||
|
plan: str = Field(..., description="Subscription plan")
|
||||||
|
interval: str = Field(..., description="Billing interval")
|
||||||
|
|
||||||
|
@field_validator("plan")
|
||||||
|
@classmethod
|
||||||
|
def validate_plan(cls, value: str) -> str:
|
||||||
|
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||||
|
raise ValueError("Invalid plan")
|
||||||
|
return value
|
||||||
|
|
||||||
|
@field_validator("interval")
|
||||||
|
@classmethod
|
||||||
|
def validate_interval(cls, value: str) -> str:
|
||||||
|
if value not in {"month", "year"}:
|
||||||
|
raise ValueError("Invalid interval")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class PartnerTenantsPayload(BaseModel):
|
||||||
|
click_id: str = Field(..., description="Click Id from partner referral link")
|
||||||
|
|
||||||
|
|
||||||
|
for model in (SubscriptionQuery, PartnerTenantsPayload):
|
||||||
|
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/billing/subscription")
|
@console_ns.route("/billing/subscription")
|
||||||
class Subscription(Resource):
|
class Subscription(Resource):
|
||||||
|
|
@ -18,20 +49,9 @@ class Subscription(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = (
|
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument(
|
|
||||||
"plan",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
location="args",
|
|
||||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
|
||||||
)
|
|
||||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/billing/invoices")
|
@console_ns.route("/billing/invoices")
|
||||||
|
|
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def put(self, partner_key: str):
|
def put(self, partner_key: str):
|
||||||
current_user, _ = current_account_with_tenant()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
click_id = args["click_id"]
|
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||||
|
click_id = args.click_id
|
||||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise BadRequest("Invalid partner_key")
|
raise BadRequest("Invalid partner_key")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
|
@ -9,16 +10,28 @@ from .. import console_ns
|
||||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
class ComplianceDownloadQuery(BaseModel):
|
||||||
|
doc_name: str = Field(..., description="Compliance document name")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
ComplianceDownloadQuery.__name__,
|
||||||
|
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/compliance/download")
|
@console_ns.route("/compliance/download")
|
||||||
class ComplianceApi(Resource):
|
class ComplianceApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
|
||||||
|
@console_ns.doc("download_compliance_document")
|
||||||
|
@console_ns.doc(description="Get compliance document download link")
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
current_user, current_tenant_id = current_account_with_tenant()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
|
@ -35,20 +37,26 @@ recommended_app_list_fields = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
class RecommendedAppsQuery(BaseModel):
|
||||||
|
language: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
RecommendedAppsQuery.__name__,
|
||||||
|
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/explore/apps")
|
@console_ns.route("/explore/apps")
|
||||||
class RecommendedAppListApi(Resource):
|
class RecommendedAppListApi(Resource):
|
||||||
@console_ns.expect(parser_apps)
|
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(recommended_app_list_fields)
|
@marshal_with(recommended_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# language args
|
# language args
|
||||||
args = parser_apps.parse_args()
|
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
|
language = args.language
|
||||||
language = args.get("language")
|
|
||||||
if language and language in languages:
|
if language and language in languages:
|
||||||
language_prefix = language
|
language_prefix = language
|
||||||
elif current_user and current_user.interface_language:
|
elif current_user and current_user.interface_language:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from flask import session
|
from flask import session
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import StrLen
|
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
@ -15,6 +15,18 @@ from . import console_ns
|
||||||
from .error import AlreadySetupError, InitValidateFailedError
|
from .error import AlreadySetupError, InitValidateFailedError
|
||||||
from .wraps import only_edition_self_hosted
|
from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class InitValidatePayload(BaseModel):
|
||||||
|
password: str = Field(..., max_length=30)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
InitValidatePayload.__name__,
|
||||||
|
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/init")
|
@console_ns.route("/init")
|
||||||
class InitValidateAPI(Resource):
|
class InitValidateAPI(Resource):
|
||||||
|
|
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
|
||||||
|
|
||||||
@console_ns.doc("validate_init_password")
|
@console_ns.doc("validate_init_password")
|
||||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"InitValidateRequest",
|
|
||||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
201,
|
201,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
|
||||||
if tenant_count > 0:
|
if tenant_count > 0:
|
||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
|
payload = InitValidatePayload.model_validate(console_ns.payload)
|
||||||
input_password = parser.parse_args()["password"]
|
input_password = payload.password
|
||||||
|
|
||||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||||
session["is_init_validated"] = False
|
session["is_init_validated"] = False
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.common import helpers
|
from controllers.common import helpers
|
||||||
|
|
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
class RemoteFileUploadPayload(BaseModel):
|
||||||
|
url: str = Field(..., description="URL to fetch")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
RemoteFileUploadPayload.__name__,
|
||||||
|
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/remote-files/upload")
|
@console_ns.route("/remote-files/upload")
|
||||||
class RemoteFileUploadApi(Resource):
|
class RemoteFileUploadApi(Resource):
|
||||||
@console_ns.expect(parser_upload)
|
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||||
@marshal_with(file_fields_with_signed_url)
|
@marshal_with(file_fields_with_signed_url)
|
||||||
def post(self):
|
def post(self):
|
||||||
args = parser_upload.parse_args()
|
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||||
|
url = args.url
|
||||||
url = args["url"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = ssrf_proxy.head(url=url)
|
resp = ssrf_proxy.head(url=url)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import StrLen, email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.model import DifySetup, db
|
from models.model import DifySetup, db
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
|
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
|
||||||
from .init_validate import get_init_validate_status
|
from .init_validate import get_init_validate_status
|
||||||
from .wraps import only_edition_self_hosted
|
from .wraps import only_edition_self_hosted
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class SetupRequestPayload(BaseModel):
|
||||||
|
email: EmailStr = Field(..., description="Admin email address")
|
||||||
|
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
|
||||||
|
password: str = Field(..., description="Admin password")
|
||||||
|
language: str | None = Field(default=None, description="Admin language")
|
||||||
|
|
||||||
|
@field_validator("password")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
SetupRequestPayload.__name__,
|
||||||
|
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/setup")
|
@console_ns.route("/setup")
|
||||||
class SetupApi(Resource):
|
class SetupApi(Resource):
|
||||||
|
|
@ -42,17 +63,7 @@ class SetupApi(Resource):
|
||||||
|
|
||||||
@console_ns.doc("setup_system")
|
@console_ns.doc("setup_system")
|
||||||
@console_ns.doc(description="Initialize system setup with admin account")
|
@console_ns.doc(description="Initialize system setup with admin account")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"SetupRequest",
|
|
||||||
{
|
|
||||||
"email": fields.String(required=True, description="Admin email address"),
|
|
||||||
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
|
|
||||||
"password": fields.String(required=True, description="Admin password"),
|
|
||||||
"language": fields.String(required=False, description="Admin language"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||||
)
|
)
|
||||||
|
|
@ -72,22 +83,15 @@ class SetupApi(Resource):
|
||||||
if not get_init_validate_status():
|
if not get_init_validate_status():
|
||||||
raise NotInitValidateError()
|
raise NotInitValidateError()
|
||||||
|
|
||||||
parser = (
|
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("name", type=StrLen(30), required=True, location="json")
|
|
||||||
.add_argument("password", type=valid_password, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
RegisterService.setup(
|
RegisterService.setup(
|
||||||
email=args["email"],
|
email=args.email,
|
||||||
name=args["name"],
|
name=args.name,
|
||||||
password=args["password"],
|
password=args.password,
|
||||||
ip_address=extract_remote_ip(request),
|
ip_address=extract_remote_ip(request),
|
||||||
language=args["language"],
|
language=args.language,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,10 @@ import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
@ -11,8 +13,14 @@ from . import console_ns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
|
||||||
"current_version", type=str, required=True, location="args", help="Current application version"
|
class VersionQuery(BaseModel):
|
||||||
|
current_version: str = Field(..., description="Current application version")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
VersionQuery.__name__,
|
||||||
|
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
|
||||||
class VersionApi(Resource):
|
class VersionApi(Resource):
|
||||||
@console_ns.doc("check_version_update")
|
@console_ns.doc("check_version_update")
|
||||||
@console_ns.doc(description="Check for application version updates")
|
@console_ns.doc(description="Check for application version updates")
|
||||||
@console_ns.expect(parser)
|
@console_ns.expect(console_ns.models[VersionQuery.__name__])
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -37,7 +45,7 @@ class VersionApi(Resource):
|
||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Check for application version updates"""
|
"""Check for application version updates"""
|
||||||
args = parser.parse_args()
|
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
check_update_url = dify_config.CHECK_UPDATE_URL
|
check_update_url = dify_config.CHECK_UPDATE_URL
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
|
|
@ -57,16 +65,16 @@ class VersionApi(Resource):
|
||||||
try:
|
try:
|
||||||
response = httpx.get(
|
response = httpx.get(
|
||||||
check_update_url,
|
check_update_url,
|
||||||
params={"current_version": args["current_version"]},
|
params={"current_version": args.current_version},
|
||||||
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
|
||||||
)
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.warning("Check update version error: %s.", str(error))
|
logger.warning("Check update version error: %s.", str(error))
|
||||||
result["version"] = args["current_version"]
|
result["version"] = args.current_version
|
||||||
return result
|
return result
|
||||||
|
|
||||||
content = json.loads(response.content)
|
content = json.loads(response.content)
|
||||||
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
|
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
|
||||||
result["version"] = content["version"]
|
result["version"] = content["version"]
|
||||||
result["release_date"] = content["releaseDate"]
|
result["release_date"] = content["releaseDate"]
|
||||||
result["release_notes"] = content["releaseNotes"]
|
result["release_notes"] = content["releaseNotes"]
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ from controllers.console.wraps import (
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.member_fields import account_fields
|
from fields.member_fields import account_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import TimestampField, email, extract_remote_ip, timezone
|
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, AccountIntegrate, InvitationCode
|
from models import Account, AccountIntegrate, InvitationCode
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
|
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class AccountDeletionFeedbackPayload(BaseModel):
|
class AccountDeletionFeedbackPayload(BaseModel):
|
||||||
email: str
|
email: EmailStr
|
||||||
feedback: str
|
feedback: str
|
||||||
|
|
||||||
@field_validator("email")
|
|
||||||
@classmethod
|
|
||||||
def validate_email(cls, value: str) -> str:
|
|
||||||
return email(value)
|
|
||||||
|
|
||||||
|
|
||||||
class EducationActivatePayload(BaseModel):
|
class EducationActivatePayload(BaseModel):
|
||||||
token: str
|
token: str
|
||||||
|
|
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ChangeEmailSendPayload(BaseModel):
|
class ChangeEmailSendPayload(BaseModel):
|
||||||
email: str
|
email: EmailStr
|
||||||
language: str | None = None
|
language: str | None = None
|
||||||
phase: str | None = None
|
phase: str | None = None
|
||||||
token: str | None = None
|
token: str | None = None
|
||||||
|
|
||||||
@field_validator("email")
|
|
||||||
@classmethod
|
|
||||||
def validate_email(cls, value: str) -> str:
|
|
||||||
return email(value)
|
|
||||||
|
|
||||||
|
|
||||||
class ChangeEmailValidityPayload(BaseModel):
|
class ChangeEmailValidityPayload(BaseModel):
|
||||||
email: str
|
email: EmailStr
|
||||||
code: str
|
code: str
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
@field_validator("email")
|
|
||||||
@classmethod
|
|
||||||
def validate_email(cls, value: str) -> str:
|
|
||||||
return email(value)
|
|
||||||
|
|
||||||
|
|
||||||
class ChangeEmailResetPayload(BaseModel):
|
class ChangeEmailResetPayload(BaseModel):
|
||||||
new_email: str
|
new_email: EmailStr
|
||||||
token: str
|
token: str
|
||||||
|
|
||||||
@field_validator("new_email")
|
|
||||||
@classmethod
|
|
||||||
def validate_email(cls, value: str) -> str:
|
|
||||||
return email(value)
|
|
||||||
|
|
||||||
|
|
||||||
class CheckEmailUniquePayload(BaseModel):
|
class CheckEmailUniquePayload(BaseModel):
|
||||||
email: str
|
email: EmailStr
|
||||||
|
|
||||||
@field_validator("email")
|
|
||||||
@classmethod
|
|
||||||
def validate_email(cls, value: str) -> str:
|
|
||||||
return email(value)
|
|
||||||
|
|
||||||
|
|
||||||
def reg(cls: type[BaseModel]):
|
def reg(cls: type[BaseModel]):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
|
|
@ -11,6 +12,26 @@ from extensions.ext_database import db
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class FileSignatureQuery(BaseModel):
|
||||||
|
timestamp: str = Field(..., description="Unix timestamp used in the signature")
|
||||||
|
nonce: str = Field(..., description="Random string for signature")
|
||||||
|
sign: str = Field(..., description="HMAC signature")
|
||||||
|
|
||||||
|
|
||||||
|
class FilePreviewQuery(FileSignatureQuery):
|
||||||
|
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
|
||||||
|
|
||||||
|
|
||||||
|
files_ns.schema_model(
|
||||||
|
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
files_ns.schema_model(
|
||||||
|
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@files_ns.route("/<uuid:file_id>/image-preview")
|
@files_ns.route("/<uuid:file_id>/image-preview")
|
||||||
class ImagePreviewApi(Resource):
|
class ImagePreviewApi(Resource):
|
||||||
|
|
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
timestamp = request.args.get("timestamp")
|
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
nonce = request.args.get("nonce")
|
timestamp = args.timestamp
|
||||||
sign = request.args.get("sign")
|
nonce = args.nonce
|
||||||
|
sign = args.sign
|
||||||
if not timestamp or not nonce or not sign:
|
|
||||||
return {"content": "Invalid request."}, 400
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||||
|
|
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
|
||||||
def get(self, file_id):
|
def get(self, file_id):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
parser = (
|
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("timestamp", type=str, required=True, location="args")
|
|
||||||
.add_argument("nonce", type=str, required=True, location="args")
|
|
||||||
.add_argument("sign", type=str, required=True, location="args")
|
|
||||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
|
|
||||||
return {"content": "Invalid request."}, 400
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
timestamp=args["timestamp"],
|
timestamp=args.timestamp,
|
||||||
nonce=args["nonce"],
|
nonce=args.nonce,
|
||||||
sign=args["sign"],
|
sign=args.sign,
|
||||||
)
|
)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
|
||||||
response.headers["Accept-Ranges"] = "bytes"
|
response.headers["Accept-Ranges"] = "bytes"
|
||||||
if upload_file.size > 0:
|
if upload_file.size > 0:
|
||||||
response.headers["Content-Length"] = str(upload_file.size)
|
response.headers["Content-Length"] = str(upload_file.size)
|
||||||
if args["as_attachment"]:
|
if args.as_attachment:
|
||||||
encoded_filename = quote(upload_file.name)
|
encoded_filename = quote(upload_file.name)
|
||||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
response.headers["Content-Type"] = "application/octet-stream"
|
response.headers["Content-Type"] = "application/octet-stream"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from controllers.common.errors import UnsupportedFileTypeError
|
from controllers.common.errors import UnsupportedFileTypeError
|
||||||
|
|
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from extensions.ext_database import db as global_db
|
from extensions.ext_database import db as global_db
|
||||||
|
|
||||||
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFileQuery(BaseModel):
|
||||||
|
timestamp: str = Field(..., description="Unix timestamp")
|
||||||
|
nonce: str = Field(..., description="Random nonce")
|
||||||
|
sign: str = Field(..., description="HMAC signature")
|
||||||
|
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||||
|
|
||||||
|
|
||||||
|
files_ns.schema_model(
|
||||||
|
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||||
class ToolFileApi(Resource):
|
class ToolFileApi(Resource):
|
||||||
|
|
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
|
||||||
def get(self, file_id, extension):
|
def get(self, file_id, extension):
|
||||||
file_id = str(file_id)
|
file_id = str(file_id)
|
||||||
|
|
||||||
parser = (
|
args = ToolFileQuery.model_validate(request.args.to_dict())
|
||||||
reqparse.RequestParser()
|
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
|
||||||
.add_argument("timestamp", type=str, required=True, location="args")
|
|
||||||
.add_argument("nonce", type=str, required=True, location="args")
|
|
||||||
.add_argument("sign", type=str, required=True, location="args")
|
|
||||||
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
if not verify_tool_file_signature(
|
|
||||||
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
|
|
||||||
):
|
|
||||||
raise Forbidden("Invalid request.")
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
|
||||||
)
|
)
|
||||||
if tool_file.size > 0:
|
if tool_file.size > 0:
|
||||||
response.headers["Content-Length"] = str(tool_file.size)
|
response.headers["Content-Length"] = str(tool_file.size)
|
||||||
if args["as_attachment"]:
|
if args.as_attachment:
|
||||||
encoded_filename = quote(tool_file.name)
|
encoded_filename = quote(tool_file.name)
|
||||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,45 @@
|
||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
from flask_restx.api import HTTPStatus
|
from flask_restx.api import HTTPStatus
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from werkzeug.datastructures import FileStorage
|
from werkzeug.datastructures import FileStorage
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.common.errors import (
|
|
||||||
FileTooLargeError,
|
|
||||||
UnsupportedFileTypeError,
|
|
||||||
)
|
|
||||||
from controllers.console.wraps import setup_required
|
|
||||||
from controllers.files import files_ns
|
|
||||||
from controllers.inner_api.plugin.wraps import get_user
|
|
||||||
from core.file.helpers import verify_plugin_file_signature
|
from core.file.helpers import verify_plugin_file_signature
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from fields.file_fields import build_file_model
|
from fields.file_fields import build_file_model
|
||||||
|
|
||||||
# Define parser for both documentation and validation
|
from ..common.errors import (
|
||||||
upload_parser = (
|
FileTooLargeError,
|
||||||
reqparse.RequestParser()
|
UnsupportedFileTypeError,
|
||||||
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
|
)
|
||||||
.add_argument(
|
from ..console.wraps import setup_required
|
||||||
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
|
from ..files import files_ns
|
||||||
)
|
from ..inner_api.plugin.wraps import get_user
|
||||||
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
|
|
||||||
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
|
|
||||||
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
|
|
||||||
|
class PluginUploadQuery(BaseModel):
|
||||||
|
timestamp: str = Field(..., description="Unix timestamp for signature verification")
|
||||||
|
nonce: str = Field(..., description="Random nonce for signature verification")
|
||||||
|
sign: str = Field(..., description="HMAC signature")
|
||||||
|
tenant_id: str = Field(..., description="Tenant identifier")
|
||||||
|
user_id: str | None = Field(default=None, description="User identifier")
|
||||||
|
|
||||||
|
|
||||||
|
files_ns.schema_model(
|
||||||
|
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@files_ns.route("/upload/for-plugin")
|
@files_ns.route("/upload/for-plugin")
|
||||||
class PluginUploadFileApi(Resource):
|
class PluginUploadFileApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@files_ns.expect(upload_parser)
|
@files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
|
||||||
@files_ns.doc("upload_plugin_file")
|
@files_ns.doc("upload_plugin_file")
|
||||||
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
|
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
|
||||||
@files_ns.doc(
|
@files_ns.doc(
|
||||||
|
|
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
|
||||||
FileTooLargeError: File exceeds size limit
|
FileTooLargeError: File exceeds size limit
|
||||||
UnsupportedFileTypeError: File type not supported
|
UnsupportedFileTypeError: File type not supported
|
||||||
"""
|
"""
|
||||||
# Parse and validate all arguments
|
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = upload_parser.parse_args()
|
|
||||||
|
|
||||||
file: FileStorage = args["file"]
|
file: FileStorage | None = request.files.get("file")
|
||||||
timestamp: str = args["timestamp"]
|
if file is None:
|
||||||
nonce: str = args["nonce"]
|
raise Forbidden("File is required.")
|
||||||
sign: str = args["sign"]
|
|
||||||
tenant_id: str = args["tenant_id"]
|
timestamp = args.timestamp
|
||||||
user_id: str | None = args.get("user_id")
|
nonce = args.nonce
|
||||||
|
sign = args.sign
|
||||||
|
tenant_id = args.tenant_id
|
||||||
|
user_id = args.user_id
|
||||||
user = get_user(tenant_id, user_id)
|
user = get_user(tenant_id, user_id)
|
||||||
|
|
||||||
filename: str | None = file.filename
|
filename: str | None = file.filename
|
||||||
|
|
|
||||||
|
|
@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
||||||
now = datetime_utils.naive_utc_now()
|
now = datetime_utils.naive_utc_now()
|
||||||
last_update = _get_last_update_timestamp(cache_key)
|
last_update = _get_last_update_timestamp(cache_key)
|
||||||
|
|
||||||
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
|
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore
|
||||||
update_values["last_used"] = values.last_used
|
update_values["last_used"] = values.last_used
|
||||||
_set_last_update_timestamp(cache_key, now)
|
_set_last_update_timestamp(cache_key, now)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import ssl
|
import ssl
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from redis import RedisError
|
from redis import RedisError
|
||||||
|
|
@ -245,7 +245,12 @@ def init_app(app: DifyApp):
|
||||||
app.extensions["redis"] = redis_client
|
app.extensions["redis"] = redis_client
|
||||||
|
|
||||||
|
|
||||||
def redis_fallback(default_return: Any | None = None):
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def redis_fallback(default_return: T | None = None): # type: ignore
|
||||||
"""
|
"""
|
||||||
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
|
||||||
|
|
||||||
|
|
@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None):
|
||||||
default_return: The value to return when a Redis operation fails. Defaults to None.
|
default_return: The value to return when a Redis operation fails. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable):
|
def decorator(func: Callable[P, R]):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except RedisError as e:
|
except RedisError as e:
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,13 @@ import uuid
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||||
from zoneinfo import available_timezones
|
from zoneinfo import available_timezones
|
||||||
|
|
||||||
from flask import Response, stream_with_context
|
from flask import Response, stream_with_context
|
||||||
from flask_restx import fields
|
from flask_restx import fields
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from pydantic.functional_validators import AfterValidator
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||||
|
|
@ -103,6 +104,9 @@ def email(email):
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
|
EmailStr = Annotated[str, AfterValidator(email)]
|
||||||
|
|
||||||
|
|
||||||
def uuid_value(value):
|
def uuid_value(value):
|
||||||
if value == "":
|
if value == "":
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
project-includes = ["."]
|
||||||
|
project-excludes = [
|
||||||
|
"tests/",
|
||||||
|
".venv",
|
||||||
|
"migrations/",
|
||||||
|
"core/rag",
|
||||||
|
]
|
||||||
|
python-platform = "linux"
|
||||||
|
python-version = "3.11.0"
|
||||||
|
infer-with-first-use = false
|
||||||
|
|
@ -1259,7 +1259,7 @@ class RegisterService:
|
||||||
return f"member_invite:token:{token}"
|
return f"member_invite:token:{token}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str):
|
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
|
||||||
"""
|
"""
|
||||||
Setup dify
|
Setup dify
|
||||||
|
|
||||||
|
|
@ -1267,6 +1267,7 @@ class RegisterService:
|
||||||
:param name: username
|
:param name: username
|
||||||
:param password: password
|
:param password: password
|
||||||
:param ip_address: ip address
|
:param ip_address: ip address
|
||||||
|
:param language: language
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
account = AccountService.create_account(
|
account = AccountService.create_account(
|
||||||
|
|
@ -1414,7 +1415,7 @@ class RegisterService:
|
||||||
return data is not None
|
return data is not None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
def revoke_token(cls, workspace_id: str | None, email: str | None, token: str):
|
||||||
if workspace_id and email:
|
if workspace_id and email:
|
||||||
email_hash = sha256(email.encode()).hexdigest()
|
email_hash = sha256(email.encode()).hexdigest()
|
||||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||||
|
|
@ -1423,7 +1424,9 @@ class RegisterService:
|
||||||
redis_client.delete(cls._get_invitation_token_key(token))
|
redis_client.delete(cls._get_invitation_token_key(token))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
|
def get_invitation_if_token_valid(
|
||||||
|
cls, workspace_id: str | None, email: str | None, token: str
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||||
if not invitation_data:
|
if not invitation_data:
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue