refactor: port reqparse to Pydantic model (#28949)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-12-05 13:05:53 +09:00 committed by GitHub
parent 6325dcf8aa
commit 7396eba1af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 900 additions and 783 deletions

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
@ -18,6 +19,30 @@ from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app") @console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list") @console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect( @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.response(200, "App updated successfully") @console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully") @console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = ( payload = InsertExploreAppPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] or "" desc = payload.desc or ""
copy_right = args["copyright"] or "" copy_right = payload.copyright or ""
privacy_policy = args["privacy_policy"] or "" privacy_policy = payload.privacy_policy or ""
custom_disclaimer = args["custom_disclaimer"] or "" custom_disclaimer = payload.custom_disclaimer or ""
else: else:
desc = site.description or args["desc"] or "" desc = site.description or payload.desc or ""
copy_right = site.copyright or args["copyright"] or "" copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right, copyright=copy_right,
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer, custom_disclaimer=custom_disclaimer,
language=args["language"], language=payload.language,
category=args["category"], category=payload.category,
position=args["position"], position=payload.position,
) )
db.session.add(recommended_app) db.session.add(recommended_app)
@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"] recommended_app.language = payload.language
recommended_app.category = args["category"] recommended_app.category = payload.category
recommended_app.position = args["position"] recommended_app.position = payload.position
app.is_public = True app.is_public = True

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -8,10 +10,21 @@ from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from services.agent_service import AgentService from services.agent_service import AgentService
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID") class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs") @console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application") @console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response( @console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
) )
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
args = parser.parse_args() args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)

View File

@ -1,7 +1,8 @@
from typing import Literal from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns from controllers.console import console_ns
@ -21,22 +22,79 @@ from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action") @console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.response(200, "Action completed successfully") @console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = AnnotationReplyPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200 return result, 200
@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting") @console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app") @console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.response(200, "Settings updated successfully") @console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200 return result, 200
@ -142,12 +184,7 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations") @console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination") @console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.response(200, "Annotations retrieved successfully") @console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -155,9 +192,10 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
def get(self, app_id): def get(self, app_id):
page = request.args.get("page", default=1, type=int) args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
limit = request.args.get("limit", default=20, type=int) page = args.page
keyword = request.args.get("keyword", default="", type=str) limit = args.limit
keyword = args.keyword
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -173,18 +211,7 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation") @console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -195,16 +222,9 @@ class AnnotationApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = CreateAnnotationPayload.model_validate(console_ns.payload)
reqparse.RequestParser() data = args.model_dump(exclude_none=True)
.add_argument("message_id", required=False, type=uuid_value, location="json") annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation return annotation
@setup_required @setup_required
@ -256,13 +276,6 @@ class AnnotationExportApi(Resource):
return response, 200 return response, 200
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation") @console_ns.doc("update_delete_annotation")
@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully") @console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
args = parser.parse_args() args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation return annotation
@setup_required @setup_required

View File

@ -1,4 +1,5 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy "AppImportCheckDependencies", app_import_check_dependencies_fields_copy
) )
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json") class AppImportPayload(BaseModel):
.add_argument("yaml_url", type=str, location="json") mode: str = Field(..., description="Import mode")
.add_argument("name", type=str, location="json") yaml_content: str | None = None
.add_argument("description", type=str, location="json") yaml_url: str | None = None
.add_argument("icon_type", type=str, location="json") name: str | None = None
.add_argument("icon", type=str, location="json") description: str | None = None
.add_argument("icon_background", type=str, location="json") icon_type: str | None = None
.add_argument("app_id", type=str, location="json") icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@console_ns.route("/apps/imports") @console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -61,7 +68,7 @@ class AppImportApi(Resource):
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = parser.parse_args() args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -70,15 +77,15 @@ class AppImportApi(Resource):
account = current_user account = current_user
result = import_service.import_app( result = import_service.import_app(
account=account, account=account,
import_mode=args["mode"], import_mode=args.mode,
yaml_content=args.get("yaml_content"), yaml_content=args.yaml_content,
yaml_url=args.get("yaml_url"), yaml_url=args.yaml_url,
name=args.get("name"), name=args.name,
description=args.get("description"), description=args.description,
icon_type=args.get("icon_type"), icon_type=args.icon_type,
icon=args.get("icon"), icon=args.icon,
icon_background=args.get("icon_background"), icon_background=args.icon_background,
app_id=args.get("app_id"), app_id=args.app_id,
) )
session.commit() session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,7 +1,8 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
@ -32,6 +33,27 @@ from services.errors.audio import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text") @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech") @console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages") @console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
console_ns.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@console_ns.response(200, "Text to speech conversion successful") @console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters") @console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model @get_app_model
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_model: App): def post(self, app_model: App):
try: try:
parser = ( payload = TextToSpeechPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
response = AudioService.transcript_tts( response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
) )
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
@ -159,9 +164,7 @@ class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices") @console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language") @console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"}) @console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
)
@console_ns.response( @console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")) 200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
) )
@ -172,12 +175,11 @@ class TextModesApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args") args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
language=args["language"], language=args.language,
) )
return response return response

View File

@ -1,7 +1,8 @@
import json import json
from enum import StrEnum from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import console_ns
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger # Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields) app_server_model = console_ns.model("AppServer", app_server_fields)
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive" INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server") @console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource): class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server") @console_ns.doc("get_app_mcp_server")
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server") @console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
console_ns.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@console_ns.response(201, "MCP server configuration created successfully", app_server_model) @console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@account_initialization_required @account_initialization_required
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = ( payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
description = args.get("description") description = payload.description
if not description: if not description:
description = app_model.description or "" description = app_model.description or ""
server = AppMCPServer( server = AppMCPServer(
name=app_model.name, name=app_model.name,
description=description, description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False), parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE, status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id, app_id=app_model.id,
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server") @console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
console_ns.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model) @console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found") @console_ns.response(404, "Server not found")
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model) @marshal_with(app_server_model)
@edit_permission_required @edit_permission_required
def put(self, app_model): def put(self, app_model):
parser = ( payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise NotFound() raise NotFound()
description = args.get("description") description = payload.description
if description is None: if description is None:
pass pass
elif not description: elif not description:
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
else: else:
server.description = description server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False) server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if args["status"]: if payload.status:
if args["status"] not in [status.value for status in AppMCPServerStatus]: if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status") raise ValueError("Invalid status")
server.status = args["status"] server.status = payload.status
db.session.commit() db.session.commit()
return server return server

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required from libs.login import login_required
from services.ops_service import OpsService from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config") @console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource): class TraceAppConfigApi(Resource):
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config") @console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application") @console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response( @console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
) )
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
try: try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config: if not trace_config:
return {"has_not_configured": True} return {"has_not_configured": True}
return trace_config return trace_config
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config") @console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application") @console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
console_ns.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@console_ns.response( @console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
) )
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = ( args = TraceConfigPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try: try:
result = OpsService.create_tracing_app_config( result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
) )
if not result: if not result:
raise TracingConfigIsExist() raise TracingConfigIsExist()
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config") @console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application") @console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
console_ns.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = ( args = TraceConfigPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try: try:
result = OpsService.update_tracing_app_config( result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
) )
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config") @console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application") @console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response(204, "Tracing configuration deleted successfully") @console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found") @console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
try: try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -1,4 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from constants.languages import supported_language from constants.languages import supported_language
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Site from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger # Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields) app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args():
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site") @console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource): class AppSite(Resource):
@console_ns.doc("update_app_site") @console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration") @console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
console_ns.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@console_ns.response(200, "Site configuration updated successfully", app_site_model) @console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@ -89,7 +73,7 @@ class AppSite(Resource):
@get_app_model @get_app_model
@marshal_with(app_site_model) @marshal_with(app_site_model)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
@ -113,7 +97,7 @@ class AppSite(Resource):
"show_workflow_steps", "show_workflow_steps",
"use_icon_as_answer_icon", "use_icon_as_answer_icon",
]: ]:
value = args.get(attr_name) value = getattr(args, attr_name)
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)

View File

@ -1,10 +1,11 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response from flask import Response, request
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import console_ns
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment): def _convert_values_to_json_serializable_object(value: Segment):
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value) return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type value_type = workflow_draft_var.value_type
return value_type.exposed_type().value return value_type.exposed_type().value
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource): class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(_create_pagination_parser()) @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_workflow_variables") @console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables") @console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
""" """
Get draft workflow Get draft workflow
""" """
parser = _create_pagination_parser() args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -323,15 +328,7 @@ class VariableApi(Resource):
@console_ns.doc("update_variable") @console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable") @console_ns.doc(description="Update a workflow variable")
@console_ns.expect( @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
console_ns.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found") @console_ns.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@ -358,16 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# } # }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),
) )
args = parser.parse_args(strict=True) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id) variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None: if variable is None:
@ -375,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id: if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}") raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None) new_name = args_model.name
raw_value = args.get(self._PATCH_VALUE_FIELD, None) raw_value = args_model.value
if new_name is None and raw_value is None: if new_name is None and raw_value is None:
return variable return variable

View File

@ -1,28 +1,53 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus from models import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
active_check_parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address") class ActivateCheckQuery(BaseModel):
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token") workspace_id: str | None = Field(default=None)
) email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check") @console_ns.route("/activate/check")
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token") @console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid") @console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(active_check_parser) @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
), ),
) )
def get(self): def get(self):
args = active_check_parser.parse_args() args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args["workspace_id"] workspaceId = args.workspace_id
reg_email = args["email"] reg_email = args.email
token = args["token"] token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation: if invitation:
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
return {"is_valid": False} return {"is_valid": False}
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
@console_ns.route("/activate") @console_ns.route("/activate")
class ActivateApi(Resource): class ActivateApi(Resource):
@console_ns.doc("activate_account") @console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token") @console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(active_parser) @console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Account activated successfully", "Account activated successfully",
@ -85,19 +99,19 @@ class ActivateApi(Resource):
) )
@console_ns.response(400, "Already activated or invalid token") @console_ns.response(400, "Already activated or invalid token")
def post(self): def post(self):
args = active_parser.parse_args() args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None: if invitation is None:
raise AlreadyActivateError() raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"] account = invitation["account"]
account.name = args["name"] account.name = args.name
account.interface_language = args["interface_language"] account.interface_language = args.interface_language
account.timezone = args["timezone"] account.timezone = args.timezone
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()

View File

@ -1,12 +1,26 @@
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/api-key-auth/data-source") @console_ns.route("/api-key-auth/data-source")
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required @is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
parser = ( payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
reqparse.RequestParser() data = payload.model_dump()
.add_argument("category", type=str, required=True, nullable=False, location="json") ApiKeyAuthService.validate_api_key_auth_args(data)
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -5,12 +5,11 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from configs import dify_config from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,5 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
) )
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models import Account from models import Account
from services.account_service import AccountService from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email") @console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource): class EmailRegisterSendEmailApi(Resource):
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterSendPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
language = "en-US" language = "en-US"
if args["language"] in languages: if args.language in languages:
language = args["language"] language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError() raise AccountInFreezeError()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token} return {"result": "success", "data": token}
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit: if is_email_register_error_rate_limit:
raise EmailRegisterLimitError() raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"]) token_data = AccountService.get_email_register_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args["code"] != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"]) AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"]) AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token( _, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"} user_email, code=args.code, additional_data={"phase": "register"}
) )
AccountService.reset_email_register_error_rate_limit(args["email"]) AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled @email_password_login_enabled
@email_register_enabled @email_register_enabled
def post(self): def post(self):
parser = ( args = EmailRegisterResetPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match # Validate passwords match
if args["new_password"] != args["password_confirm"]: if args.new_password != args.password_confirm:
raise PasswordMismatchError() raise PasswordMismatchError()
# Validate token and get register data # Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"]) register_data = AccountService.get_email_register_data(args.token)
if not register_data: if not register_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"]) AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "") email = register_data.get("email", "")
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account: if account:
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
else: else:
account = self._create_new_account(email, args["password_confirm"]) account = self._create_new_account(email, args.password_confirm)
if not account: if not account:
raise AccountNotFoundError() raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

View File

@ -2,7 +2,8 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models import Account from models import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password") @console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email") @console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email") @console_ns.doc(description="Send password reset email")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
console_ns.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Email sent successfully", "Email sent successfully",
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
account=account, account=account,
email=args["email"], email=args.email,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
) )
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code") @console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code") @console_ns.doc(description="Verify password reset code")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
console_ns.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Code verified successfully", "Code verified successfully",
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit: if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError() raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"]) token_data = AccountService.get_reset_password_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args["code"] != token_data.get("code"): if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"]) AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token( _, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"} user_email, code=args.code, additional_data={"phase": "reset"}
) )
AccountService.reset_forgot_password_error_rate_limit(args["email"]) AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password") @console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token") @console_ns.doc(description="Reset password with verification token")
@console_ns.expect( @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
console_ns.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@console_ns.response( @console_ns.response(
200, 200,
"Password reset successfully", "Password reset successfully",
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
parser = ( args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match # Validate passwords match
if args["new_password"] != args["password_confirm"]: if args.new_password != args.password_confirm:
raise PasswordMismatchError() raise PasswordMismatchError()
# Validate token and get reset data # Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"]) reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase # Must use token in reset phase
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError() raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password # Generate secure salt and hash password
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt) password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "") email = reset_data.get("email", "")

View File

@ -1,6 +1,7 @@
import flask_login import flask_login
from flask import make_response, request from flask import make_response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
import services import services
from configs import dify_config from configs import dify_config
@ -23,7 +24,7 @@ from controllers.console.error import (
) )
from controllers.console.wraps import email_password_login_enabled, setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from libs.token import ( from libs.token import (
clear_access_token_from_cookie, clear_access_token_from_cookie,
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login") @console_ns.route("/login")
class LoginApi(Resource): class LoginApi(Resource):
@ -47,41 +78,36 @@ class LoginApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = ( args = LoginPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError() raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit: if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError() raise EmailPasswordLoginLimitError()
invitation = args["invite_token"] # TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation: if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try: try:
if invitation: if invitation:
data = invitation.get("data", {}) data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None invitee_email = data.get("email") if data else None
if invitee_email != args["email"]: if invitee_email != args.email:
raise InvalidEmailError() raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) account = AccountService.authenticate(args.email, args.password, args.invite_token)
else: else:
account = AccountService.authenticate(args["email"], args["password"]) account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
raise AccountBannedError() raise AccountBannedError()
except services.errors.account.AccountPasswordError: except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"]) AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError() raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
@ -97,7 +123,7 @@ class LoginApi(Resource):
} }
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})
@ -134,25 +160,21 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
token = AccountService.send_reset_password_email( token = AccountService.send_reset_password_email(
email=args["email"], email=args.email,
account=account, account=account,
language=language, language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register, is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login") @console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginSendEmailApi(Resource):
@setup_required @setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans": if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
try: try:
account = AccountService.get_user_through_email(args["email"]) account = AccountService.get_user_through_email(args.email)
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language) token = AccountService.send_email_code_login_email(email=args.email, language=language)
else: else:
raise AccountNotFound() raise AccountNotFound()
else: else:
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity") @console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource): class EmailCodeLoginApi(Resource):
@setup_required @setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
def post(self): def post(self):
parser = ( args = EmailCodeLoginPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args["email"] user_email = args.email
language = args["language"] language = args.language
token_data = AccountService.get_email_code_login_data(args["token"]) token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if token_data["email"] != args["email"]: if token_data["email"] != args.email:
raise InvalidEmailError() raise InvalidEmailError()
if token_data["code"] != args["code"]: if token_data["code"] != args.code:
raise EmailCodeError() raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"]) AccountService.revoke_email_code_login_token(args.token)
try: try:
account = AccountService.get_user_through_email(user_email) account = AccountService.get_user_through_email(user_email)
except AccountRegisterError: except AccountRegisterError:
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError: except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded() raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"]) AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body # Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"}) response = make_response({"result": "success"})

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request from flask import jsonify, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
@ -20,15 +21,34 @@ R = TypeVar("R")
T = TypeVar("T") T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
class OAuthProviderRequest(BaseModel):
client_id: str
redirect_uri: str
class OAuthTokenRequest(BaseModel):
client_id: str
grant_type: str
code: str | None = None
client_secret: str | None = None
redirect_uri: str | None = None
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json") json_data = request.get_json()
parsed_args = parser.parse_args() if json_data is None:
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required") raise BadRequest("client_id is required")
payload = OAuthClientPayload.model_validate(json_data)
client_id = payload.client_id
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json") payload = OAuthProviderRequest.model_validate(request.get_json())
parsed_args = parser.parse_args() redirect_uri = payload.redirect_uri
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid # check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris: if redirect_uri not in oauth_provider_app.redirect_uris:
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
@setup_required @setup_required
@oauth_server_client_id_required @oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp): def post(self, oauth_provider_app: OAuthProviderApp):
parser = ( payload = OAuthTokenRequest.model_validate(request.get_json())
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
try: try:
grant_type = OAuthGrantType(parsed_args["grant_type"]) grant_type = OAuthGrantType(payload.grant_type)
except ValueError: except ValueError:
raise BadRequest("invalid grant_type") raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE: if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]: if not payload.code:
raise BadRequest("code is required") raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret: if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid") raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid") raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token( access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id grant_type, code=payload.code, client_id=oauth_provider_app.client_id
) )
return jsonable_encoder( return jsonable_encoder(
{ {
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
} }
) )
elif grant_type == OAuthGrantType.REFRESH_TOKEN: elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]: if not payload.refresh_token:
raise BadRequest("refresh_token is required") raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token( access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
) )
return jsonable_encoder( return jsonable_encoder(
{ {

View File

@ -1,6 +1,8 @@
import base64 import base64
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription") @console_ns.route("/billing/subscription")
class Subscription(Resource): class Subscription(Resource):
@ -18,20 +49,9 @@ class Subscription(Resource):
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id) return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices") @console_ns.route("/billing/invoices")
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
@only_edition_cloud @only_edition_cloud
def put(self, partner_key: str): def put(self, partner_key: str):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try: try:
click_id = args["click_id"] args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8") decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception: except Exception:
raise BadRequest("Invalid partner_key") raise BadRequest("Invalid partner_key")

View File

@ -1,5 +1,6 @@
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -9,16 +10,28 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
doc_name: str = Field(..., description="Compliance document name")
console_ns.schema_model(
ComplianceDownloadQuery.__name__,
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/compliance/download") @console_ns.route("/compliance/download")
class ComplianceApi(Resource): class ComplianceApi(Resource):
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
@console_ns.doc("download_compliance_document")
@console_ns.doc(description="Get compliance document download link")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args") args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device") device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages from constants.languages import languages
from controllers.console import console_ns from controllers.console import console_ns
@ -35,20 +37,26 @@ recommended_app_list_fields = {
} }
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args") class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/explore/apps") @console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@console_ns.expect(parser_apps) @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)
def get(self): def get(self):
# language args # language args
args = parser_apps.parse_args() args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
language = args.language
language = args.get("language")
if language and language in languages: if language and language in languages:
language_prefix = language language_prefix = language
elif current_user and current_user.interface_language: elif current_user and current_user.interface_language:

View File

@ -1,13 +1,13 @@
import os import os
from flask import session from flask import session
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -15,6 +15,18 @@ from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/init") @console_ns.route("/init")
class InitValidateAPI(Resource): class InitValidateAPI(Resource):
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
@console_ns.doc("validate_init_password") @console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition") @console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect( @console_ns.expect(console_ns.models[InitValidatePayload.__name__])
console_ns.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@console_ns.response( @console_ns.response(
201, 201,
"Success", "Success",
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
if tenant_count > 0: if tenant_count > 0:
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json") payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = parser.parse_args()["password"] input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False session["is_init_validated"] = False

View File

@ -1,7 +1,8 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services import services
from controllers.common import helpers from controllers.common import helpers
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
} }
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/remote-files/upload") @console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@console_ns.expect(parser_upload) @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
args = parser_upload.parse_args() args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
url = args["url"]
try: try:
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)

View File

@ -1,8 +1,9 @@
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup, db from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
password: str = Field(..., description="Admin password")
language: str | None = Field(default=None, description="Admin language")
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
console_ns.schema_model(
SetupRequestPayload.__name__,
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/setup") @console_ns.route("/setup")
class SetupApi(Resource): class SetupApi(Resource):
@ -42,17 +63,7 @@ class SetupApi(Resource):
@console_ns.doc("setup_system") @console_ns.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account") @console_ns.doc(description="Initialize system setup with admin account")
@console_ns.expect( @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
console_ns.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)
@console_ns.response( @console_ns.response(
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
) )
@ -72,22 +83,15 @@ class SetupApi(Resource):
if not get_init_validate_status(): if not get_init_validate_status():
raise NotInitValidateError() raise NotInitValidateError()
parser = ( args = SetupRequestPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
# setup # setup
RegisterService.setup( RegisterService.setup(
email=args["email"], email=args.email,
name=args["name"], name=args.name,
password=args["password"], password=args.password,
ip_address=extract_remote_ip(request), ip_address=extract_remote_ip(request),
language=args["language"], language=args.language,
) )
return {"result": "success"}, 201 return {"result": "success"}, 201

View File

@ -2,8 +2,10 @@ import json
import logging import logging
import httpx import httpx
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from packaging import version from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config from configs import dify_config
@ -11,8 +13,14 @@ from . import console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version" class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
console_ns.schema_model(
VersionQuery.__name__,
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
) )
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
class VersionApi(Resource): class VersionApi(Resource):
@console_ns.doc("check_version_update") @console_ns.doc("check_version_update")
@console_ns.doc(description="Check for application version updates") @console_ns.doc(description="Check for application version updates")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[VersionQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Success", "Success",
@ -37,7 +45,7 @@ class VersionApi(Resource):
) )
def get(self): def get(self):
"""Check for application version updates""" """Check for application version updates"""
args = parser.parse_args() args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL check_update_url = dify_config.CHECK_UPDATE_URL
result = { result = {
@ -57,16 +65,16 @@ class VersionApi(Resource):
try: try:
response = httpx.get( response = httpx.get(
check_update_url, check_update_url,
params={"current_version": args["current_version"]}, params={"current_version": args.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0), timeout=httpx.Timeout(timeout=10.0, connect=3.0),
) )
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args.current_version
return result return result
content = json.loads(response.content) content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"] result["version"] = content["version"]
result["release_date"] = content["releaseDate"] result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"] result["release_notes"] = content["releaseNotes"]

View File

@ -37,7 +37,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService from services.account_service import AccountService
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
class AccountDeletionFeedbackPayload(BaseModel): class AccountDeletionFeedbackPayload(BaseModel):
email: str email: EmailStr
feedback: str feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel): class EducationActivatePayload(BaseModel):
token: str token: str
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel): class ChangeEmailSendPayload(BaseModel):
email: str email: EmailStr
language: str | None = None language: str | None = None
phase: str | None = None phase: str | None = None
token: str | None = None token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel): class ChangeEmailValidityPayload(BaseModel):
email: str email: EmailStr
code: str code: str
token: str token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel): class ChangeEmailResetPayload(BaseModel):
new_email: str new_email: EmailStr
token: str token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel): class CheckEmailUniquePayload(BaseModel):
email: str email: EmailStr
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
def reg(cls: type[BaseModel]): def reg(cls: type[BaseModel]):

View File

@ -1,7 +1,8 @@
from urllib.parse import quote from urllib.parse import quote
from flask import Response, request from flask import Response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services
@ -11,6 +12,26 @@ from extensions.ext_database import db
from services.account_service import TenantService from services.account_service import TenantService
from services.file_service import FileService from services.file_service import FileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class FileSignatureQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp used in the signature")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
class FilePreviewQuery(FileSignatureQuery):
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
files_ns.schema_model(
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
files_ns.schema_model(
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/<uuid:file_id>/image-preview") @files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource): class ImagePreviewApi(Resource):
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
timestamp = request.args.get("timestamp") args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
nonce = request.args.get("nonce") timestamp = args.timestamp
sign = request.args.get("sign") nonce = args.nonce
sign = args.sign
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
try: try:
generator, mimetype = FileService(db.engine).get_image_preview( generator, mimetype = FileService(db.engine).get_image_preview(
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
parser = ( args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
try: try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id, file_id=file_id,
timestamp=args["timestamp"], timestamp=args.timestamp,
nonce=args["nonce"], nonce=args.nonce,
sign=args["sign"], sign=args.sign,
) )
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
response.headers["Accept-Ranges"] = "bytes" response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0: if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size) response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]: if args.as_attachment:
encoded_filename = quote(upload_file.name) encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream" response.headers["Content-Type"] = "application/octet-stream"

View File

@ -1,7 +1,8 @@
from urllib.parse import quote from urllib.parse import quote
from flask import Response from flask import Response, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError from controllers.common.errors import UnsupportedFileTypeError
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ToolFileQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp")
nonce: str = Field(..., description="Random nonce")
sign: str = Field(..., description="HMAC signature")
as_attachment: bool = Field(default=False, description="Download as attachment")
files_ns.schema_model(
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/tools/<uuid:file_id>.<string:extension>") @files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource): class ToolFileApi(Resource):
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
def get(self, file_id, extension): def get(self, file_id, extension):
file_id = str(file_id) file_id = str(file_id)
parser = ( args = ToolFileQuery.model_validate(request.args.to_dict())
reqparse.RequestParser() if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not verify_tool_file_signature(
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
raise Forbidden("Invalid request.") raise Forbidden("Invalid request.")
try: try:
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
) )
if tool_file.size > 0: if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size) response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]: if args.as_attachment:
encoded_filename = quote(tool_file.name) encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

View File

@ -1,40 +1,45 @@
from mimetypes import guess_extension from mimetypes import guess_extension
from flask_restx import Resource, reqparse from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import files_ns
from controllers.inner_api.plugin.wraps import get_user
from core.file.helpers import verify_plugin_file_signature from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model from fields.file_fields import build_file_model
# Define parser for both documentation and validation from ..common.errors import (
upload_parser = ( FileTooLargeError,
reqparse.RequestParser() UnsupportedFileTypeError,
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") )
.add_argument( from ..console.wraps import setup_required
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" from ..files import files_ns
) from ..inner_api.plugin.wraps import get_user
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation") DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
class PluginUploadQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp for signature verification")
nonce: str = Field(..., description="Random nonce for signature verification")
sign: str = Field(..., description="HMAC signature")
tenant_id: str = Field(..., description="Tenant identifier")
user_id: str | None = Field(default=None, description="User identifier")
files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@files_ns.route("/upload/for-plugin") @files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource): class PluginUploadFileApi(Resource):
@setup_required @setup_required
@files_ns.expect(upload_parser) @files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
@files_ns.doc("upload_plugin_file") @files_ns.doc("upload_plugin_file")
@files_ns.doc(description="Upload a file for plugin usage with signature verification") @files_ns.doc(description="Upload a file for plugin usage with signature verification")
@files_ns.doc( @files_ns.doc(
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
FileTooLargeError: File exceeds size limit FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported UnsupportedFileTypeError: File type not supported
""" """
# Parse and validate all arguments args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = upload_parser.parse_args()
file: FileStorage = args["file"] file: FileStorage | None = request.files.get("file")
timestamp: str = args["timestamp"] if file is None:
nonce: str = args["nonce"] raise Forbidden("File is required.")
sign: str = args["sign"]
tenant_id: str = args["tenant_id"] timestamp = args.timestamp
user_id: str | None = args.get("user_id") nonce = args.nonce
sign = args.sign
tenant_id = args.tenant_id
user_id = args.user_id
user = get_user(tenant_id, user_id) user = get_user(tenant_id, user_id)
filename: str | None = file.filename filename: str | None = file.filename

View File

@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
now = datetime_utils.naive_utc_now() now = datetime_utils.naive_utc_now()
last_update = _get_last_update_timestamp(cache_key) last_update = _get_last_update_timestamp(cache_key)
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore
update_values["last_used"] = values.last_used update_values["last_used"] = values.last_used
_set_last_update_timestamp(cache_key, now) _set_last_update_timestamp(cache_key, now)

View File

@ -3,7 +3,7 @@ import logging
import ssl import ssl
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
import redis import redis
from redis import RedisError from redis import RedisError
@ -245,7 +245,12 @@ def init_app(app: DifyApp):
app.extensions["redis"] = redis_client app.extensions["redis"] = redis_client
def redis_fallback(default_return: Any | None = None): P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def redis_fallback(default_return: T | None = None): # type: ignore
""" """
decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. decorator to handle Redis operation exceptions and return a default value when Redis is unavailable.
@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None):
default_return: The value to return when a Redis operation fails. Defaults to None. default_return: The value to return when a Redis operation fails. Defaults to None.
""" """
def decorator(func: Callable): def decorator(func: Callable[P, R]):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except RedisError as e: except RedisError as e:

View File

@ -10,12 +10,13 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from zoneinfo import available_timezones from zoneinfo import available_timezones
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restx import fields from flask_restx import fields
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.functional_validators import AfterValidator
from configs import dify_config from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
@ -103,6 +104,9 @@ def email(email):
raise ValueError(error) raise ValueError(error)
EmailStr = Annotated[str, AfterValidator(email)]
def uuid_value(value): def uuid_value(value):
if value == "": if value == "":
return str(value) return str(value)

10
api/pyrefly.toml Normal file
View File

@ -0,0 +1,10 @@
project-includes = ["."]
project-excludes = [
"tests/",
".venv",
"migrations/",
"core/rag",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false

View File

@ -1259,7 +1259,7 @@ class RegisterService:
return f"member_invite:token:{token}" return f"member_invite:token:{token}"
@classmethod @classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str): def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
""" """
Setup dify Setup dify
@ -1267,6 +1267,7 @@ class RegisterService:
:param name: username :param name: username
:param password: password :param password: password
:param ip_address: ip address :param ip_address: ip address
:param language: language
""" """
try: try:
account = AccountService.create_account( account = AccountService.create_account(
@ -1414,7 +1415,7 @@ class RegisterService:
return data is not None return data is not None
@classmethod @classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str | None, email: str | None, token: str):
if workspace_id and email: if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest() email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
@ -1423,7 +1424,9 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token)) redis_client.delete(cls._get_invitation_token_key(token))
@classmethod @classmethod
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None: def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email) invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data: if not invitation_data:
return None return None